Unverified Commit f284aa32 authored by Merve Noyan's avatar Merve Noyan Committed by GitHub
Browse files

steps strategy fix for PushtoHubCallback (#16138)

parent e3645fd2
...@@ -264,7 +264,7 @@ class PushToHubCallback(Callback): ...@@ -264,7 +264,7 @@ class PushToHubCallback(Callback):
save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"epoch"`): save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"epoch"`):
The checkpoint save strategy to adopt during training. Possible values are: The checkpoint save strategy to adopt during training. Possible values are:
- `"no"`: No save is done during training. - `"no"`: Save is done at the end of training.
- `"epoch"`: Save is done at the end of each epoch. - `"epoch"`: Save is done at the end of each epoch.
- `"steps"`: Save is done every `save_steps` - `"steps"`: Save is done every `save_steps`
save_steps (`int`, *optional*): save_steps (`int`, *optional*):
...@@ -331,7 +331,7 @@ class PushToHubCallback(Callback): ...@@ -331,7 +331,7 @@ class PushToHubCallback(Callback):
self.training_history = [] self.training_history = []
def on_train_batch_end(self, batch, logs=None): def on_train_batch_end(self, batch, logs=None):
if self.save_strategy == IntervalStrategy.STEPS and batch + 1 % self.save_steps == 0: if self.save_strategy == IntervalStrategy.STEPS and (batch + 1) % self.save_steps == 0:
if self.last_job is not None and not self.last_job.is_done: if self.last_job is not None and not self.last_job.is_done:
return # The last upload is still running, don't start another return # The last upload is still running, don't start another
self.model.save_pretrained(self.output_dir) self.model.save_pretrained(self.output_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment