Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
6faf2832
Unverified
Commit
6faf2832
authored
Aug 23, 2022
by
Joao Gante
Committed by
GitHub
Aug 23, 2022
Browse files
CLI: Don't check the model head when there is no model head (#18733)
parent
43869808
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
2 deletions
+12
-2
src/transformers/commands/pt_to_tf.py
src/transformers/commands/pt_to_tf.py
+12
-2
No files found.
src/transformers/commands/pt_to_tf.py
View file @
6faf2832
...
...
@@ -286,7 +286,12 @@ class PTtoTFCommand(BaseTransformersCLICommand):
crossload_differences
=
self
.
find_pt_tf_differences
(
pt_outputs
,
tf_from_pt_outputs
)
output_differences
=
{
k
:
v
for
k
,
v
in
crossload_differences
.
items
()
if
"hidden"
not
in
k
}
hidden_differences
=
{
k
:
v
for
k
,
v
in
crossload_differences
.
items
()
if
"hidden"
in
k
}
max_crossload_output_diff
=
max
(
output_differences
.
values
())
if
len
(
output_differences
)
==
0
and
architectures
is
not
None
:
raise
ValueError
(
f
"Something went wrong -- the config file has architectures (
{
architectures
}
), but no model head"
" output was found. All outputs start with 'hidden'"
)
max_crossload_output_diff
=
max
(
output_differences
.
values
())
if
output_differences
else
0.0
max_crossload_hidden_diff
=
max
(
hidden_differences
.
values
())
if
max_crossload_output_diff
>
MAX_ERROR
or
max_crossload_hidden_diff
>
self
.
_max_hidden_error
:
raise
ValueError
(
...
...
@@ -310,7 +315,12 @@ class PTtoTFCommand(BaseTransformersCLICommand):
conversion_differences
=
self
.
find_pt_tf_differences
(
pt_outputs
,
tf_outputs
)
output_differences
=
{
k
:
v
for
k
,
v
in
conversion_differences
.
items
()
if
"hidden"
not
in
k
}
hidden_differences
=
{
k
:
v
for
k
,
v
in
conversion_differences
.
items
()
if
"hidden"
in
k
}
max_conversion_output_diff
=
max
(
output_differences
.
values
())
if
len
(
output_differences
)
==
0
and
architectures
is
not
None
:
raise
ValueError
(
f
"Something went wrong -- the config file has architectures (
{
architectures
}
), but no model head"
" output was found. All outputs start with 'hidden'"
)
max_conversion_output_diff
=
max
(
output_differences
.
values
())
if
output_differences
else
0.0
max_conversion_hidden_diff
=
max
(
hidden_differences
.
values
())
if
max_conversion_output_diff
>
MAX_ERROR
or
max_conversion_hidden_diff
>
self
.
_max_hidden_error
:
raise
ValueError
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment