Unverified Commit fa2fbed3 authored by Julien Plu, committed by GitHub

Better None gradients handling in TF Trainer (#4469)

* Better None gradients handling

* Apply Style

* Apply Style
parent e708bb75
@@ -141,7 +141,7 @@ class TFTrainer:
             self.optimizer = tf.keras.optimizers.get(
                 {"class_name": self.args.optimizer_name, "config": {"learning_rate": self.args.learning_rate}}
             )
-        logger.info("Created an/a {} optimizer".format(self.optimizer))
+        logger.info("Created an/a {} optimizer".format(self.args.optimizer_name))

     def _create_checkpoint_manager(self, max_to_keep: int = 5, load_model: bool = True) -> None:
         """
@@ -335,12 +335,8 @@ class TFTrainer:
             gradient / tf.cast(gradient_scale, gradient.dtype) for gradient in self.gradient_accumulator.gradients
         ]
         gradients = [(tf.clip_by_value(grad, -self.args.max_grad_norm, self.args.max_grad_norm)) for grad in gradients]
-        vars = self.model.trainable_variables
-
-        if self.args.mode in ["token-classification", "question-answering"]:
-            vars = [var for var in self.model.trainable_variables if "pooler" not in var.name]
-        self.optimizer.apply_gradients(list(zip(gradients, vars)))
+        self.optimizer.apply_gradients(list(zip(gradients, self.model.trainable_variables)))

         self.gradient_accumulator.reset()

     def _accumulate_next_gradients(self):
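This hunk drops the per-task filtering of "pooler" variables and applies the accumulated gradients to every trainable variable, so the gradient list and the variable list always stay aligned. A minimal sketch of that apply step under assumed names (apply_accumulated_gradients, max_grad_norm, and the toy model below are illustrative, not the trainer's API):

import tensorflow as tf

def apply_accumulated_gradients(optimizer, model, accumulated_gradients, accumulation_steps, max_grad_norm=1.0):
    # Average the summed gradients over the number of accumulation steps.
    gradients = [g / tf.cast(accumulation_steps, g.dtype) for g in accumulated_gradients]
    # Clip element-wise, mirroring the tf.clip_by_value call in the trainer.
    gradients = [tf.clip_by_value(g, -max_grad_norm, max_grad_norm) for g in gradients]
    # Pair every gradient with its variable; nothing is filtered out, so the
    # two lists are aligned by construction.
    optimizer.apply_gradients(list(zip(gradients, model.trainable_variables)))

# Example usage with a toy model and a single "accumulation" step:
model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
optimizer = tf.keras.optimizers.SGD(0.1)
with tf.GradientTape() as tape:
    loss = tf.reduce_mean(model(tf.ones((1, 4))) ** 2)
grads = tape.gradient(loss, model.trainable_variables)
apply_accumulated_gradients(optimizer, model, grads, accumulation_steps=1)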
@@ -375,12 +371,10 @@ class TFTrainer:
     def _forward(self, features, labels):
         """Forwards a training example and accumulates the gradients."""
         per_example_loss, _ = self._run_model(features, labels, True)
-        vars = self.model.trainable_variables
-
-        if self.args.mode in ["token-classification", "question-answering"]:
-            vars = [var for var in self.model.trainable_variables if "pooler" not in var.name]
-
-        gradients = self.optimizer.get_gradients(per_example_loss, vars)
+        gradients = tf.gradients(per_example_loss, self.model.trainable_variables)
+        gradients = [
+            g if g is not None else tf.zeros_like(v) for g, v in zip(gradients, self.model.trainable_variables)
+        ]

         self.gradient_accumulator(gradients)
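This hunk is the None-gradient handling named in the commit title: tf.gradients returns None for any trainable variable that does not contribute to the loss (for example the unused pooler head that the removed code filtered out for token-classification and question-answering), and each None is replaced with a zero tensor so the gradient accumulator always receives one well-shaped gradient per variable. A standalone sketch of the same pattern, using an eager GradientTape rather than the graph-mode tf.gradients call shown above (which requires running inside a graph context); the variables here are illustrative only:

import tensorflow as tf

used = tf.Variable(3.0, name="used")
unused = tf.Variable(1.0, name="pooler_like_unused")  # never touches the loss

with tf.GradientTape() as tape:
    loss = used * used

variables = [used, unused]
gradients = tape.gradient(loss, variables)  # [6.0, None]: None for the unconnected variable

# Replace None gradients with zeros so downstream code (accumulator, optimizer)
# sees exactly one tensor per variable, with matching shapes.
gradients = [
    g if g is not None else tf.zeros_like(v) for g, v in zip(gradients, variables)
]

print([g.numpy() for g in gradients])  # [6.0, 0.0]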