Unverified commit 7c1b9128, authored by Sylvain Gugger and committed by GitHub
Browse files

Index RNG states by global rank in saves (#17852)

parent 7cf52a49
@@ -1922,12 +1922,12 @@ class Trainer:
 if checkpoint is None:
     return
-local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
-if local_rank != -1:
-    rng_file = os.path.join(checkpoint, f"rng_state_{local_rank}.pth")
+if self.args.world_size > 1:
+    process_index = self.args.process_index
+    rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
     if not os.path.isfile(rng_file):
         logger.info(
-            f"Didn't find an RNG file for process {local_rank}, if you are resuming a training that "
+            f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
             "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
         )
         return
@@ -2067,11 +2067,10 @@ class Trainer:
 # not yet exist.
 os.makedirs(output_dir, exist_ok=True)
-local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
-if local_rank == -1:
+if self.args.world_size <= 1:
     torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
 else:
-    torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))
+    torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))
 if self.args.push_to_hub:
     self._push_from_checkpoint(output_dir)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment