chenpangpang / transformers · Commits

Unverified commit fa2fbed3, authored May 20, 2020 by Julien Plu, committed by GitHub on May 20, 2020
Better None gradients handling in TF Trainer (#4469)
* Better None gradients handling
* Apply Style
* Apply Style
parent e708bb75
Showing 1 changed file with 6 additions and 12 deletions.

src/transformers/trainer_tf.py (+6, -12)
@@ -141,7 +141,7 @@ class TFTrainer:
             self.optimizer = tf.keras.optimizers.get(
                 {"class_name": self.args.optimizer_name, "config": {"learning_rate": self.args.learning_rate}}
             )
-        logger.info("Created an/a {} optimizer".format(self.optimizer))
+        logger.info("Created an/a {} optimizer".format(self.args.optimizer_name))
 
     def _create_checkpoint_manager(self, max_to_keep: int = 5, load_model: bool = True) -> None:
         """
@@ -335,12 +335,8 @@ class TFTrainer:
             gradient / tf.cast(gradient_scale, gradient.dtype) for gradient in self.gradient_accumulator.gradients
         ]
         gradients = [(tf.clip_by_value(grad, -self.args.max_grad_norm, self.args.max_grad_norm)) for grad in gradients]
-        vars = self.model.trainable_variables
-
-        if self.args.mode in ["token-classification", "question-answering"]:
-            vars = [var for var in self.model.trainable_variables if "pooler" not in var.name]
-
-        self.optimizer.apply_gradients(list(zip(gradients, vars)))
+        self.optimizer.apply_gradients(list(zip(gradients, self.model.trainable_variables)))
         self.gradient_accumulator.reset()
 
     def _accumulate_next_gradients(self):
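Note: with None gradients zero-filled in _forward (next hunk), _step no longer needs to drop pooler variables per task mode and can apply the accumulated gradients to every trainable variable. A minimal, self-contained sketch of that scale/clip/apply flow follows; it uses a toy Keras model and illustrative names (num_accum_steps, accumulated_grads are not from trainer_tf.py, and the real Trainer derives its gradient scale from its GradientAccumulator and the replica count).

import tensorflow as tf

# Toy setup; all names here are illustrative, not taken from trainer_tf.py.
model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(4,))])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
max_grad_norm = 1.0
num_accum_steps = 4

# Stand-in for gradients summed over num_accum_steps micro-batches.
x = tf.random.normal((8, 4))
y = tf.random.normal((8, 2))
with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(model(x) - y)) * num_accum_steps
accumulated_grads = tape.gradient(loss, model.trainable_variables)

# Average the accumulated gradients, clip them, and apply them to all
# trainable variables, mirroring the simplified _step above.
grads = [g / tf.cast(num_accum_steps, g.dtype) for g in accumulated_grads]
grads = [tf.clip_by_value(g, -max_grad_norm, max_grad_norm) for g in grads]
optimizer.apply_gradients(zip(grads, model.trainable_variables))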
@@ -375,12 +371,10 @@ class TFTrainer:
     def _forward(self, features, labels):
         """Forwards a training example and accumulates the gradients."""
         per_example_loss, _ = self._run_model(features, labels, True)
-        vars = self.model.trainable_variables
-
-        if self.args.mode in ["token-classification", "question-answering"]:
-            vars = [var for var in self.model.trainable_variables if "pooler" not in var.name]
-
-        gradients = self.optimizer.get_gradients(per_example_loss, vars)
+        gradients = tf.gradients(per_example_loss, self.model.trainable_variables)
+        gradients = [
+            g if g is not None else tf.zeros_like(v) for g, v in zip(gradients, self.model.trainable_variables)
+        ]
 
         self.gradient_accumulator(gradients)
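Note: the core of the change is the "g if g is not None else tf.zeros_like(v)" comprehension. When part of the model (for example an unused pooler head) does not contribute to the loss, its gradient comes back as None, and neither the gradient accumulator nor apply_gradients accepts None entries; zero tensors of the matching shape keep every variable in the update without hard-coding task modes. A minimal sketch of the pattern follows; it uses tf.GradientTape so the example runs eagerly, whereas the Trainer's _forward computes gradients with tf.gradients inside graph mode, and the variable names are illustrative, not from the source.

import tensorflow as tf

# Two variables, but the loss only touches `used`; `unused` plays the role of
# an unused pooler layer, so its gradient comes back as None.
used = tf.Variable(tf.ones((3,)), name="used")
unused = tf.Variable(tf.ones((3,)), name="unused_pooler")
variables = [used, unused]

with tf.GradientTape() as tape:
    loss = tf.reduce_sum(used ** 2)

grads = tape.gradient(loss, variables)  # -> [Tensor, None]

# The fix from this commit: replace None gradients with zeros of the matching
# shape so the accumulator and apply_gradients never see None.
grads = [g if g is not None else tf.zeros_like(v) for g, v in zip(grads, variables)]

optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
optimizer.apply_gradients(zip(grads, variables))  # `unused` just gets a zero update

The zero gradient is a no-op update for variables that did not contribute to the loss, which is the effect the old per-mode pooler filter was approximating.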