Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
06a6a4bd
Unverified
Commit
06a6a4bd
authored
Aug 25, 2022
by
Joao Gante
Committed by
GitHub
Aug 25, 2022
Browse files
CLI: Improved error control and updated hub requirement (#18752)
parent
e9442440
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
22 additions
and
18 deletions
+22
-18
src/transformers/commands/pt_to_tf.py
src/transformers/commands/pt_to_tf.py
+22
-18
No files found.
src/transformers/commands/pt_to_tf.py
View file @
06a6a4bd
...
@@ -59,7 +59,7 @@ def convert_command_factory(args: Namespace):
...
@@ -59,7 +59,7 @@ def convert_command_factory(args: Namespace):
return
PTtoTFCommand
(
return
PTtoTFCommand
(
args
.
model_name
,
args
.
model_name
,
args
.
local_dir
,
args
.
local_dir
,
args
.
max_
hidden_
error
,
args
.
max_error
,
args
.
new_weights
,
args
.
new_weights
,
args
.
no_pr
,
args
.
no_pr
,
args
.
push
,
args
.
push
,
...
@@ -96,12 +96,11 @@ class PTtoTFCommand(BaseTransformersCLICommand):
...
@@ -96,12 +96,11 @@ class PTtoTFCommand(BaseTransformersCLICommand):
help
=
"Optional local directory of the model repository. Defaults to /tmp/{model_name}"
,
help
=
"Optional local directory of the model repository. Defaults to /tmp/{model_name}"
,
)
)
train_parser
.
add_argument
(
train_parser
.
add_argument
(
"--max-
hidden-
error"
,
"--max-error"
,
type
=
float
,
type
=
float
,
default
=
MAX_ERROR
,
default
=
MAX_ERROR
,
help
=
(
help
=
(
f
"Maximum error tolerance for hidden layer outputs. Defaults to
{
MAX_ERROR
}
. If you suspect the hidden"
f
"Maximum error tolerance. Defaults to
{
MAX_ERROR
}
. This flag should be avoided, use at your own risk."
" layers outputs will be used for downstream applications, avoid increasing this tolerance."
),
),
)
)
train_parser
.
add_argument
(
train_parser
.
add_argument
(
...
@@ -168,7 +167,7 @@ class PTtoTFCommand(BaseTransformersCLICommand):
...
@@ -168,7 +167,7 @@ class PTtoTFCommand(BaseTransformersCLICommand):
self
,
self
,
model_name
:
str
,
model_name
:
str
,
local_dir
:
str
,
local_dir
:
str
,
max_
hidden_
error
:
float
,
max_error
:
float
,
new_weights
:
bool
,
new_weights
:
bool
,
no_pr
:
bool
,
no_pr
:
bool
,
push
:
bool
,
push
:
bool
,
...
@@ -178,7 +177,7 @@ class PTtoTFCommand(BaseTransformersCLICommand):
...
@@ -178,7 +177,7 @@ class PTtoTFCommand(BaseTransformersCLICommand):
self
.
_logger
=
logging
.
get_logger
(
"transformers-cli/pt_to_tf"
)
self
.
_logger
=
logging
.
get_logger
(
"transformers-cli/pt_to_tf"
)
self
.
_model_name
=
model_name
self
.
_model_name
=
model_name
self
.
_local_dir
=
local_dir
if
local_dir
else
os
.
path
.
join
(
"/tmp"
,
model_name
)
self
.
_local_dir
=
local_dir
if
local_dir
else
os
.
path
.
join
(
"/tmp"
,
model_name
)
self
.
_max_
hidden_
error
=
max_
hidden_
error
self
.
_max_error
=
max_error
self
.
_new_weights
=
new_weights
self
.
_new_weights
=
new_weights
self
.
_no_pr
=
no_pr
self
.
_no_pr
=
no_pr
self
.
_push
=
push
self
.
_push
=
push
...
@@ -239,9 +238,10 @@ class PTtoTFCommand(BaseTransformersCLICommand):
...
@@ -239,9 +238,10 @@ class PTtoTFCommand(BaseTransformersCLICommand):
return
pt_input
,
tf_input
return
pt_input
,
tf_input
def
run
(
self
):
def
run
(
self
):
if
version
.
parse
(
huggingface_hub
.
__version__
)
<
version
.
parse
(
"0.8.1"
):
# hub version 0.9.0 introduced the possibility of programmatically opening PRs with normal write tokens.
if
version
.
parse
(
huggingface_hub
.
__version__
)
<
version
.
parse
(
"0.9.0"
):
raise
ImportError
(
raise
ImportError
(
"The huggingface_hub version must be >= 0.
8.1
to use this command. Please update your huggingface_hub"
"The huggingface_hub version must be >= 0.
9.0
to use this command. Please update your huggingface_hub"
" installation."
" installation."
)
)
else
:
else
:
...
@@ -293,13 +293,13 @@ class PTtoTFCommand(BaseTransformersCLICommand):
...
@@ -293,13 +293,13 @@ class PTtoTFCommand(BaseTransformersCLICommand):
)
)
max_crossload_output_diff
=
max
(
output_differences
.
values
())
if
output_differences
else
0.0
max_crossload_output_diff
=
max
(
output_differences
.
values
())
if
output_differences
else
0.0
max_crossload_hidden_diff
=
max
(
hidden_differences
.
values
())
max_crossload_hidden_diff
=
max
(
hidden_differences
.
values
())
if
max_crossload_output_diff
>
MAX_ERROR
or
max_crossload_hidden_diff
>
self
.
_max_
hidden_
error
:
if
max_crossload_output_diff
>
self
.
_max_error
or
max_crossload_hidden_diff
>
self
.
_max_error
:
raise
ValueError
(
raise
ValueError
(
"The cross-loaded TensorFlow model has different outputs, something went wrong!
\n
"
"The cross-loaded TensorFlow model has different outputs, something went wrong!
\n
"
+
f
"
\n
List of maximum output differences above the threshold (
{
MAX_ERROR
}
):
\n
"
+
f
"
\n
List of maximum output differences above the threshold (
{
self
.
_max_error
}
):
\n
"
+
"
\n
"
.
join
([
f
"
{
k
}
:
{
v
:.
3
e
}
"
for
k
,
v
in
output_differences
.
items
()
if
v
>
MAX_ERROR
])
+
"
\n
"
.
join
([
f
"
{
k
}
:
{
v
:.
3
e
}
"
for
k
,
v
in
output_differences
.
items
()
if
v
>
self
.
_max_error
])
+
f
"
\n\n
List of maximum hidden layer differences above the threshold (
{
self
.
_max_
hidden_
error
}
):
\n
"
+
f
"
\n\n
List of maximum hidden layer differences above the threshold (
{
self
.
_max_error
}
):
\n
"
+
"
\n
"
.
join
([
f
"
{
k
}
:
{
v
:.
3
e
}
"
for
k
,
v
in
hidden_differences
.
items
()
if
v
>
self
.
_max_
hidden_
error
])
+
"
\n
"
.
join
([
f
"
{
k
}
:
{
v
:.
3
e
}
"
for
k
,
v
in
hidden_differences
.
items
()
if
v
>
self
.
_max_error
])
)
)
# Save the weights in a TF format (if needed) and confirms that the results are still good
# Save the weights in a TF format (if needed) and confirms that the results are still good
...
@@ -322,13 +322,13 @@ class PTtoTFCommand(BaseTransformersCLICommand):
...
@@ -322,13 +322,13 @@ class PTtoTFCommand(BaseTransformersCLICommand):
)
)
max_conversion_output_diff
=
max
(
output_differences
.
values
())
if
output_differences
else
0.0
max_conversion_output_diff
=
max
(
output_differences
.
values
())
if
output_differences
else
0.0
max_conversion_hidden_diff
=
max
(
hidden_differences
.
values
())
max_conversion_hidden_diff
=
max
(
hidden_differences
.
values
())
if
max_conversion_output_diff
>
MAX_ERROR
or
max_conversion_hidden_diff
>
self
.
_max_
hidden_
error
:
if
max_conversion_output_diff
>
self
.
_max_error
or
max_conversion_hidden_diff
>
self
.
_max_error
:
raise
ValueError
(
raise
ValueError
(
"The converted TensorFlow model has different outputs, something went wrong!
\n
"
"The converted TensorFlow model has different outputs, something went wrong!
\n
"
+
f
"
\n
List of maximum output differences above the threshold (
{
MAX_ERROR
}
):
\n
"
+
f
"
\n
List of maximum output differences above the threshold (
{
self
.
_max_error
}
):
\n
"
+
"
\n
"
.
join
([
f
"
{
k
}
:
{
v
:.
3
e
}
"
for
k
,
v
in
output_differences
.
items
()
if
v
>
MAX_ERROR
])
+
"
\n
"
.
join
([
f
"
{
k
}
:
{
v
:.
3
e
}
"
for
k
,
v
in
output_differences
.
items
()
if
v
>
self
.
_max_error
])
+
f
"
\n\n
List of maximum hidden layer differences above the threshold (
{
self
.
_max_
hidden_
error
}
):
\n
"
+
f
"
\n\n
List of maximum hidden layer differences above the threshold (
{
self
.
_max_error
}
):
\n
"
+
"
\n
"
.
join
([
f
"
{
k
}
:
{
v
:.
3
e
}
"
for
k
,
v
in
hidden_differences
.
items
()
if
v
>
self
.
_max_
hidden_
error
])
+
"
\n
"
.
join
([
f
"
{
k
}
:
{
v
:.
3
e
}
"
for
k
,
v
in
hidden_differences
.
items
()
if
v
>
self
.
_max_error
])
)
)
commit_message
=
"Update TF weights"
if
self
.
_new_weights
else
"Add TF weights"
commit_message
=
"Update TF weights"
if
self
.
_new_weights
else
"Add TF weights"
...
@@ -348,6 +348,10 @@ class PTtoTFCommand(BaseTransformersCLICommand):
...
@@ -348,6 +348,10 @@ class PTtoTFCommand(BaseTransformersCLICommand):
f
"Maximum conversion output difference=
{
max_conversion_output_diff
:.
3
e
}
; "
f
"Maximum conversion output difference=
{
max_conversion_output_diff
:.
3
e
}
; "
f
"Maximum conversion hidden layer difference=
{
max_conversion_hidden_diff
:.
3
e
}
;
\n
"
f
"Maximum conversion hidden layer difference=
{
max_conversion_hidden_diff
:.
3
e
}
;
\n
"
)
)
if
self
.
_max_error
>
MAX_ERROR
:
commit_descrition
+=
(
f
"
\n\n
CAUTION: The maximum admissible error was manually increased to
{
self
.
_max_error
}
!"
)
if
self
.
_extra_commit_description
:
if
self
.
_extra_commit_description
:
commit_descrition
+=
"
\n\n
"
+
self
.
_extra_commit_description
commit_descrition
+=
"
\n\n
"
+
self
.
_extra_commit_description
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment