"vscode:/vscode.git/clone" did not exist on "7f6d37502997226862a441dddfcdd7d247479772"
Unverified commit 36dfc317, authored by Julien Plu, committed by GitHub
Browse files

TF Checkpoints (#4831)

* Align checkpoint dir with the PT trainer

* Use args for max to keep checkpoints
parent 439f1cab
......@@ -230,8 +230,9 @@ class TFTrainer:
with self.args.strategy.scope():
optimizer, lr_scheduler = self.get_optimizers()
iterations = optimizer.iterations
folder = os.path.join(self.args.output_dir, PREFIX_CHECKPOINT_DIR)
ckpt = tf.train.Checkpoint(optimizer=optimizer, model=self.model)
self.model.ckpt_manager = tf.train.CheckpointManager(ckpt, PREFIX_CHECKPOINT_DIR, max_to_keep=5)
self.model.ckpt_manager = tf.train.CheckpointManager(ckpt, folder, max_to_keep=self.args.save_total_limit)
if self.model.ckpt_manager.latest_checkpoint:
logger.info(
......@@ -401,17 +402,12 @@ class TFTrainer:
def save_model(self, output_dir: Optional[str] = None):
"""
Save the pretrained model and create a Tensorflow saved model.
Save the pretrained model.
"""
output_dir = output_dir if output_dir is not None else self.args.output_dir
logger.info("Saving model in {}".format(output_dir))
path = os.path.join(self.args.output_dir, "saved_model")
logger.info("Saving model in {}".format(path))
os.makedirs(path, exist_ok=True)
if not isinstance(self.model, TFPreTrainedModel):
raise ValueError("Trainer.model appears to not be a PreTrainedModel")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment