Unverified Commit 36dfc317 authored by Julien Plu's avatar Julien Plu Committed by GitHub
Browse files

TF Checkpoints (#4831)

* Align checkpoint dir with the PT trainer

* Use args for max to keep checkpoints
parent 439f1cab
...@@ -230,8 +230,9 @@ class TFTrainer: ...@@ -230,8 +230,9 @@ class TFTrainer:
with self.args.strategy.scope(): with self.args.strategy.scope():
optimizer, lr_scheduler = self.get_optimizers() optimizer, lr_scheduler = self.get_optimizers()
iterations = optimizer.iterations iterations = optimizer.iterations
folder = os.path.join(self.args.output_dir, PREFIX_CHECKPOINT_DIR)
ckpt = tf.train.Checkpoint(optimizer=optimizer, model=self.model) ckpt = tf.train.Checkpoint(optimizer=optimizer, model=self.model)
self.model.ckpt_manager = tf.train.CheckpointManager(ckpt, PREFIX_CHECKPOINT_DIR, max_to_keep=5) self.model.ckpt_manager = tf.train.CheckpointManager(ckpt, folder, max_to_keep=self.args.save_total_limit)
if self.model.ckpt_manager.latest_checkpoint: if self.model.ckpt_manager.latest_checkpoint:
logger.info( logger.info(
...@@ -401,17 +402,12 @@ class TFTrainer: ...@@ -401,17 +402,12 @@ class TFTrainer:
def save_model(self, output_dir: Optional[str] = None): def save_model(self, output_dir: Optional[str] = None):
""" """
Save the pretrained model and create a Tensorflow saved model. Save the pretrained model.
""" """
output_dir = output_dir if output_dir is not None else self.args.output_dir output_dir = output_dir if output_dir is not None else self.args.output_dir
logger.info("Saving model in {}".format(output_dir)) logger.info("Saving model in {}".format(output_dir))
path = os.path.join(self.args.output_dir, "saved_model")
logger.info("Saving model in {}".format(path))
os.makedirs(path, exist_ok=True)
if not isinstance(self.model, TFPreTrainedModel): if not isinstance(self.model, TFPreTrainedModel):
raise ValueError("Trainer.model appears to not be a PreTrainedModel") raise ValueError("Trainer.model appears to not be a PreTrainedModel")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment