"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "ae32f3afefcd3288df0af47d8499ae6024c66612"
Unverified commit 4d0ea3d2, authored by Shivam Shrirao, committed by GitHub


CUDA rng_state_all is used when saving in distributed mode, so the same should also be used when loading (#23045)

The CUDA RNG state should be restored with set_rng_state_all in distributed mode, because the states of all devices were saved.
parent 521a8ffa
@@ -2327,10 +2327,10 @@ class Trainer:
         torch.random.set_rng_state(checkpoint_rng_state["cpu"])
         if torch.cuda.is_available():
             if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
-                torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
+                torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
             else:
                 try:
-                    torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
+                    torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
                 except Exception as e:
                     logger.info(
                         f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}"
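For context, here is a minimal sketch of the save/restore symmetry this change enforces, written outside the Trainer with a plain dict checkpoint; the is_distributed flag is a hypothetical stand-in for the self.args.parallel_mode == ParallelMode.DISTRIBUTED check used in the real code.

# Sketch only: a plain dict checkpoint and a hypothetical `is_distributed` flag,
# not the Trainer's actual checkpoint format.
import torch

def save_rng_state(is_distributed: bool) -> dict:
    state = {"cpu": torch.random.get_rng_state()}
    if torch.cuda.is_available():
        if is_distributed:
            # One RNG state per visible GPU, returned as a list of tensors.
            state["cuda"] = torch.cuda.random.get_rng_state_all()
        else:
            # Single tensor for the current device.
            state["cuda"] = torch.cuda.random.get_rng_state()
    return state

def load_rng_state(state: dict, is_distributed: bool) -> None:
    torch.random.set_rng_state(state["cpu"])
    if torch.cuda.is_available():
        if is_distributed:
            # Restore the list of per-device states saved above; set_rng_state
            # expects a single tensor, so it cannot restore this list.
            torch.cuda.random.set_rng_state_all(state["cuda"])
        else:
            torch.cuda.random.set_rng_state(state["cuda"])

Because get_rng_state_all returns a list with one state per GPU while set_rng_state expects a single tensor, restoring a distributed checkpoint with set_rng_state would not work, hence the swap in the diff above.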