Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
66e86567
Unverified
Commit
66e86567
authored
Jun 08, 2022
by
Joao Gante
Committed by
GitHub
Jun 08, 2022
Browse files
CLI: Print all different tensors on exception (#17612)
parent
e9d51387
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
31 additions
and
28 deletions
+31
-28
src/transformers/commands/pt_to_tf.py
src/transformers/commands/pt_to_tf.py
+31
-28
No files found.
src/transformers/commands/pt_to_tf.py
View file @
66e86567
...
@@ -87,10 +87,10 @@ class PTtoTFCommand(BaseTransformersCLICommand):
...
@@ -87,10 +87,10 @@ class PTtoTFCommand(BaseTransformersCLICommand):
train_parser
.
set_defaults
(
func
=
convert_command_factory
)
train_parser
.
set_defaults
(
func
=
convert_command_factory
)
@
staticmethod
@
staticmethod
def
compare_pt_tf_model
s
(
pt_model
,
pt_input
,
tf_model
,
tf_input
):
def
find_pt_tf_difference
s
(
pt_model
,
pt_input
,
tf_model
,
tf_input
):
"""
"""
Compares the TensorFlow and PyTorch models, given their inputs, returning a
tuple with the maximum observed
Compares the TensorFlow and PyTorch models, given their inputs, returning a
dictionary with all tensor
difference
and its source
.
difference
s
.
"""
"""
pt_outputs
=
pt_model
(
**
pt_input
,
output_hidden_states
=
True
)
pt_outputs
=
pt_model
(
**
pt_input
,
output_hidden_states
=
True
)
tf_outputs
=
tf_model
(
**
tf_input
,
output_hidden_states
=
True
)
tf_outputs
=
tf_model
(
**
tf_input
,
output_hidden_states
=
True
)
...
@@ -104,18 +104,14 @@ class PTtoTFCommand(BaseTransformersCLICommand):
...
@@ -104,18 +104,14 @@ class PTtoTFCommand(BaseTransformersCLICommand):
f
"
{
tf_out_attrs
}
)"
f
"
{
tf_out_attrs
}
)"
)
)
# 2. For each output attribute, ALL values must be the same
# 2. For each output attribute, computes the difference
def
_compate_pt_tf_models
(
pt_out
,
tf_out
,
attr_name
=
""
):
def
_find_pt_tf_differences
(
pt_out
,
tf_out
,
differences
,
attr_name
=
""
):
max_difference
=
0
max_difference_source
=
""
# If the current attribute is a tensor, it is a leaf and we make the comparison. Otherwise, we will dig in
# If the current attribute is a tensor, it is a leaf and we make the comparison. Otherwise, we will dig in
# recursivelly, keeping the name of the attribute.
# recursivelly, keeping the name of the attribute.
if
isinstance
(
pt_out
,
(
torch
.
Tensor
)):
if
isinstance
(
pt_out
,
torch
.
Tensor
):
difference
=
np
.
max
(
np
.
abs
(
pt_out
.
detach
().
numpy
()
-
tf_out
.
numpy
()))
tensor_difference
=
np
.
max
(
np
.
abs
(
pt_out
.
detach
().
numpy
()
-
tf_out
.
numpy
()))
if
difference
>
max_difference
:
differences
[
attr_name
]
=
tensor_difference
max_difference
=
difference
max_difference_source
=
attr_name
else
:
else
:
root_name
=
attr_name
root_name
=
attr_name
for
i
,
pt_item
in
enumerate
(
pt_out
):
for
i
,
pt_item
in
enumerate
(
pt_out
):
...
@@ -127,14 +123,11 @@ class PTtoTFCommand(BaseTransformersCLICommand):
...
@@ -127,14 +123,11 @@ class PTtoTFCommand(BaseTransformersCLICommand):
else
:
else
:
branch_name
=
root_name
+
f
"[
{
i
}
]"
branch_name
=
root_name
+
f
"[
{
i
}
]"
tf_item
=
tf_out
[
i
]
tf_item
=
tf_out
[
i
]
difference
,
difference_source
=
_compate_pt_tf_models
(
pt_item
,
tf_item
,
branch_name
)
differences
=
_find_pt_tf_differences
(
pt_item
,
tf_item
,
differences
,
branch_name
)
if
difference
>
max_difference
:
max_difference
=
difference
max_difference_source
=
difference_source
return
max_
difference
,
max_difference_source
return
difference
s
return
_
compate_pt_tf_model
s
(
pt_outputs
,
tf_outputs
)
return
_
find_pt_tf_difference
s
(
pt_outputs
,
tf_outputs
,
{}
)
def
__init__
(
self
,
model_name
:
str
,
local_dir
:
str
,
no_pr
:
bool
,
new_weights
:
bool
,
*
args
):
def
__init__
(
self
,
model_name
:
str
,
local_dir
:
str
,
no_pr
:
bool
,
new_weights
:
bool
,
*
args
):
self
.
_logger
=
logging
.
get_logger
(
"transformers-cli/pt_to_tf"
)
self
.
_logger
=
logging
.
get_logger
(
"transformers-cli/pt_to_tf"
)
...
@@ -213,11 +206,15 @@ class PTtoTFCommand(BaseTransformersCLICommand):
...
@@ -213,11 +206,15 @@ class PTtoTFCommand(BaseTransformersCLICommand):
tf_input
.
update
({
"decoder_input_ids"
:
tf
.
convert_to_tensor
(
decoder_input_ids
)})
tf_input
.
update
({
"decoder_input_ids"
:
tf
.
convert_to_tensor
(
decoder_input_ids
)})
# Confirms that cross loading PT weights into TF worked.
# Confirms that cross loading PT weights into TF worked.
crossload_diff
,
diff_source
=
self
.
compare_pt_tf_models
(
pt_model
,
pt_input
,
tf_from_pt_model
,
tf_input
)
crossload_differences
=
self
.
find_pt_tf_differences
(
pt_model
,
pt_input
,
tf_from_pt_model
,
tf_input
)
if
crossload_diff
>=
MAX_ERROR
:
max_crossload_diff
=
max
(
crossload_differences
.
values
())
if
max_crossload_diff
>
MAX_ERROR
:
raise
ValueError
(
raise
ValueError
(
"The cross-loaded TF model has different outputs, something went wrong! (max difference ="
"The cross-loaded TensorFlow model has different outputs, something went wrong! Exaustive list of"
f
"
{
crossload_diff
:.
3
e
}
, observed in
{
diff_source
}
)"
f
" maximum tensor differences above the error threshold (
{
MAX_ERROR
}
):
\n
"
+
"
\n
"
.
join
(
[
f
"
{
key
}
:
{
value
:.
3
e
}
"
for
key
,
value
in
crossload_differences
.
items
()
if
value
>
MAX_ERROR
]
)
)
)
# Save the weights in a TF format (if needed) and confirms that the results are still good
# Save the weights in a TF format (if needed) and confirms that the results are still good
...
@@ -226,11 +223,15 @@ class PTtoTFCommand(BaseTransformersCLICommand):
...
@@ -226,11 +223,15 @@ class PTtoTFCommand(BaseTransformersCLICommand):
tf_from_pt_model
.
save_weights
(
tf_weights_path
)
tf_from_pt_model
.
save_weights
(
tf_weights_path
)
del
tf_from_pt_model
# will no longer be used, and may have a large memory footprint
del
tf_from_pt_model
# will no longer be used, and may have a large memory footprint
tf_model
=
tf_class
.
from_pretrained
(
self
.
_local_dir
)
tf_model
=
tf_class
.
from_pretrained
(
self
.
_local_dir
)
converted_diff
,
diff_source
=
self
.
compare_pt_tf_models
(
pt_model
,
pt_input
,
tf_model
,
tf_input
)
conversion_differences
=
self
.
find_pt_tf_differences
(
pt_model
,
pt_input
,
tf_model
,
tf_input
)
if
converted_diff
>=
MAX_ERROR
:
max_conversion_diff
=
max
(
conversion_differences
.
values
())
if
max_conversion_diff
>
MAX_ERROR
:
raise
ValueError
(
raise
ValueError
(
"The converted TF model has different outputs, something went wrong! (max difference ="
"The converted TensorFlow model has different outputs, something went wrong! Exaustive list of maximum"
f
"
{
converted_diff
:.
3
e
}
, observed in
{
diff_source
}
)"
f
" tensor differences above the error threshold (
{
MAX_ERROR
}
):
\n
"
+
"
\n
"
.
join
(
[
f
"
{
key
}
:
{
value
:.
3
e
}
"
for
key
,
value
in
conversion_differences
.
items
()
if
value
>
MAX_ERROR
]
)
)
)
if
not
self
.
_no_pr
:
if
not
self
.
_no_pr
:
...
@@ -245,8 +246,10 @@ class PTtoTFCommand(BaseTransformersCLICommand):
...
@@ -245,8 +246,10 @@ class PTtoTFCommand(BaseTransformersCLICommand):
create_pr
=
True
,
create_pr
=
True
,
pr_commit_summary
=
"Add TF weights"
,
pr_commit_summary
=
"Add TF weights"
,
pr_commit_description
=
(
pr_commit_description
=
(
f
"Validated by the `pt_to_tf` CLI. Max crossload output difference=
{
crossload_diff
:.
3
e
}
;"
"Model converted by the `transformers`' `pt_to_tf` CLI -- all converted model outputs and"
f
" Max converted output difference=
{
converted_diff
:.
3
e
}
."
" hidden layers were validated against its Pytorch counterpart. Maximum crossload output"
f
" difference=
{
max_crossload_diff
:.
3
e
}
; Maximum converted output"
f
" difference=
{
max_conversion_diff
:.
3
e
}
."
),
),
)
)
self
.
_logger
.
warn
(
f
"PR open in
{
hub_pr_url
}
"
)
self
.
_logger
.
warn
(
f
"PR open in
{
hub_pr_url
}
"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment