Unverified commit 7dfdf793, authored by Teven and committed by GitHub
Browse files

Fixing case in which `Trainer` hung while saving model in distributed training (#7365)

* remote debugging

* remote debugging

* moved _store_flos call

* moved _store_flos call

* moved _store_flos call

* removed debugging artefacts
parent 0ccb6f5c
...@@ -812,6 +812,7 @@ class Trainer: ...@@ -812,6 +812,7 @@ class Trainer:
checkpoint_folder += f"-run-{run_id}" checkpoint_folder += f"-run-{run_id}"
output_dir = os.path.join(self.args.output_dir, checkpoint_folder) output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
self.store_flos()
self.save_model(output_dir) self.save_model(output_dir)
if self.is_world_process_zero(): if self.is_world_process_zero():
...@@ -1151,7 +1152,6 @@ class Trainer: ...@@ -1151,7 +1152,6 @@ class Trainer:
raise ValueError("Trainer.model appears to not be a PreTrainedModel") raise ValueError("Trainer.model appears to not be a PreTrainedModel")
xm.rendezvous("saving_checkpoint") xm.rendezvous("saving_checkpoint")
self._store_flos()
self.model.save_pretrained(output_dir) self.model.save_pretrained(output_dir)
if self.tokenizer is not None: if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir) self.tokenizer.save_pretrained(output_dir)
...@@ -1164,7 +1164,6 @@ class Trainer: ...@@ -1164,7 +1164,6 @@ class Trainer:
# They can then be reloaded using `from_pretrained()` # They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, PreTrainedModel): if not isinstance(self.model, PreTrainedModel):
raise ValueError("Trainer.model appears to not be a PreTrainedModel") raise ValueError("Trainer.model appears to not be a PreTrainedModel")
self._store_flos()
self.model.save_pretrained(output_dir) self.model.save_pretrained(output_dir)
if self.tokenizer is not None: if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir) self.tokenizer.save_pretrained(output_dir)
...@@ -1175,7 +1174,7 @@ class Trainer: ...@@ -1175,7 +1174,7 @@ class Trainer:
self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False
) )
def _store_flos(self): def store_flos(self):
# Storing the number of floating-point operations that went into the model # Storing the number of floating-point operations that went into the model
if self.total_flos is not None: if self.total_flos is not None:
if self.args.local_rank != -1: if self.args.local_rank != -1:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment