"kubernetes/helm/vscode:/vscode.git/clone" did not exist on "331fe04df7dcfb2d22e2ecc39525b6cf74fae575"
Unverified Commit c52e429b, authored by Hz, Ji, committed by GitHub

Reproducible checkpoint for NPU (#27208)

* Save the NPU's RNG states when saving a checkpoint, and restore them after the data-skip phase when resuming training.

* re-trigger ci

* re-trigger ci
parent 7adaefe2
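
For context, here is a minimal, self-contained sketch of the save/restore round-trip this commit implements. It is not the Trainer's exact code; it assumes the Ascend `torch_npu` extension is installed, which patches `torch` with the `torch.npu` namespace and the `torch.npu.random` API used in the diff below.

    import random

    import numpy as np
    import torch

    # Saving side: collect every generator's state in one dict, as the Trainer
    # does for python/numpy/cpu, and add the NPU entry when an NPU is present.
    rng_states = {
        "python": random.getstate(),
        "numpy": np.random.get_state(),
        "cpu": torch.random.get_rng_state(),
    }
    if hasattr(torch, "npu") and torch.npu.is_available():  # provided by torch_npu
        rng_states["npu"] = torch.npu.random.get_rng_state()
    torch.save(rng_states, "rng_state.pth")

    # Resuming side: restore each state after the dataloader's data-skip phase,
    # so the generators continue exactly where the interrupted run left off.
    checkpoint_rng_state = torch.load("rng_state.pth")
    random.setstate(checkpoint_rng_state["python"])
    np.random.set_state(checkpoint_rng_state["numpy"])
    torch.random.set_rng_state(checkpoint_rng_state["cpu"])
    if hasattr(torch, "npu") and torch.npu.is_available():
        torch.npu.random.set_rng_state(checkpoint_rng_state["npu"])
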
@@ -144,6 +144,7 @@ from .utils import (
     is_sagemaker_mp_enabled,
     is_torch_compile_available,
     is_torch_neuroncore_available,
+    is_torch_npu_available,
     is_torch_tpu_available,
     logging,
     strtobool,
@@ -2321,6 +2322,17 @@ class Trainer:
             )
         if is_torch_tpu_available():
             xm.set_rng_state(checkpoint_rng_state["xla"])
+        if is_torch_npu_available():
+            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+                torch.npu.random.set_rng_state_all(checkpoint_rng_state["npu"])
+            else:
+                try:
+                    torch.npu.random.set_rng_state(checkpoint_rng_state["npu"])
+                except Exception as e:
+                    logger.info(
+                        f"Didn't manage to set back the RNG states of the NPU because of the following error:\n {e}"
+                        "\nThis won't yield the same results as if the training had not been interrupted."
+                    )

     def _save_checkpoint(self, model, trial, metrics=None):
         # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
@@ -2423,6 +2435,12 @@ class Trainer:
         if is_torch_tpu_available():
             rng_states["xla"] = xm.get_rng_state()
+        if is_torch_npu_available():
+            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+                rng_states["npu"] = torch.npu.random.get_rng_state_all()
+            else:
+                rng_states["npu"] = torch.npu.random.get_rng_state()
+
         # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
         # not yet exist.
         os.makedirs(output_dir, exist_ok=True)
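
A note on the two branches above: in distributed mode the commit uses the `_all` variants, which (mirroring `torch.cuda.random.get_rng_state_all`) capture one generator state per visible NPU rather than only the current device's, so every device's generator can be restored on resume. A hedged sketch of the difference, again assuming `torch_npu` is installed:

    import torch

    # Distributed run: one state tensor per visible NPU, so every device's
    # generator is captured (mirrors torch.cuda.random.get_rng_state_all).
    states = torch.npu.random.get_rng_state_all()   # -> list of Tensors

    # Single-device run: only the current NPU's generator state.
    state = torch.npu.random.get_rng_state()        # -> Tensor
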