Unverified Commit 518dd127 authored by cavdard, committed by GitHub

Updated checkpoint support for Sagemaker Model Parallel (#17219)



* adding partial checkpoint support for optimizer state

* formatted trainer.py

* Refactoring based on comments

* reformatting

* Update src/transformers/trainer.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/trainer.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/trainer.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Cavdar <dcavdar@a07817b12d7e.ant.amazon.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 71d18d08
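
For orientation before the diff: on the save side, each rank now writes a partial optimizer checkpoint through the smdistributed.modelparallel (smp) API instead of consolidating a full optimizer state dict on rdp_rank 0. Below is a minimal standalone sketch of that flow, not part of the PR; it assumes only the smp calls that appear in the diff (local_state_dict(gather_if_shard=False), smp.barrier, smp.rdp_rank, smp.save with partial=True) plus the standard `import smdistributed.modelparallel.torch as smp` alias, and the helper name is illustrative.

import os

import smdistributed.modelparallel.torch as smp

OPTIMIZER_NAME = "optimizer.pt"  # the Trainer's file-name constant


def save_partial_optimizer_state(optimizer, output_dir):
    # Each rank takes its local, non-gathered shard of the optimizer state.
    opt_state_dict = optimizer.local_state_dict(gather_if_shard=False)
    smp.barrier()
    # With optimizer state sharding every rank writes its own partial file;
    # otherwise only rdp_rank 0 writes.
    if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
        smp.save(
            opt_state_dict,
            os.path.join(output_dir, OPTIMIZER_NAME),
            partial=True,
            v3=smp.state.cfg.shard_optimizer_state,
        )
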
src/transformers/trainer.py

@@ -18,6 +18,7 @@ The Trainer class, to easily train a 🤗 Transformers from scratch or finetune
 import contextlib
 import functools
+import glob
 import inspect
 import math
 import os
@@ -1305,7 +1306,7 @@ class Trainer:
             if resume_from_checkpoint is None:
                 raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
 
-        if resume_from_checkpoint is not None:
+        if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled():
             self._load_from_checkpoint(resume_from_checkpoint)
 
         # If model was re-initialized, put it on the right device and update self.model_wrapped
@@ -1406,6 +1407,9 @@ class Trainer:
 
         model = self._wrap_model(self.model_wrapped)
 
+        if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
+            self._load_from_checkpoint(resume_from_checkpoint, model)
+
         # for the rest of this function `model` is the outside model, whether it was wrapped or not
         if model is not self.model:
             self.model_wrapped = model
@@ -1671,6 +1675,8 @@ class Trainer:
                 xm.rendezvous("load_best_model_at_end")
             elif args.local_rank != -1:
                 dist.barrier()
+            elif is_sagemaker_mp_enabled():
+                smp.barrier()
 
             self._load_best_model()
@@ -1693,7 +1699,12 @@ class Trainer:
 
         return TrainOutput(self.state.global_step, train_loss, metrics)
 
-    def _load_from_checkpoint(self, resume_from_checkpoint):
+    def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
+
+        if model is None:
+            model = self.model
+
+        strict_load = is_sagemaker_mp_enabled()
         if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)) and not os.path.isfile(
             os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
         ):
@@ -1718,20 +1729,22 @@ class Trainer:
             # We load the model state dict on the CPU to avoid an OOM error.
             state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
             # If the model is on the GPU, it still works!
-            load_result = self.model.load_state_dict(state_dict, strict=False)
-            self._issue_warnings_after_load(load_result)
+            load_result = model.load_state_dict(state_dict, strict=strict_load)
+            if not strict_load:
+                self._issue_warnings_after_load(load_result)
             # release memory
             del state_dict
         else:
             # We load the sharded checkpoint
-            load_result = load_sharded_checkpoint(self.model, resume_from_checkpoint, strict=False)
-            self._issue_warnings_after_load(load_result)
+            load_result = load_sharded_checkpoint(model, resume_from_checkpoint, strict=strict_load)
+            if not strict_load:
+                self._issue_warnings_after_load(load_result)
 
     def _load_best_model(self):
         logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
         best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
+        strict_load = is_sagemaker_mp_enabled()
+        model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
         if os.path.exists(best_model_path):
             if self.deepspeed:
                 # temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping
@@ -1748,12 +1761,13 @@ class Trainer:
                 # We load the model state dict on the CPU to avoid an OOM error.
                 state_dict = torch.load(best_model_path, map_location="cpu")
                 # If the model is on the GPU, it still works!
-                load_result = self.model.load_state_dict(state_dict, strict=False)
-                self._issue_warnings_after_load(load_result)
+                load_result = model.load_state_dict(state_dict, strict=strict_load)
+                if not strict_load:
+                    self._issue_warnings_after_load(load_result)
         elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
-            # Best model is a sharded checkpoint
-            load_result = load_sharded_checkpoint(self.model, self.state.best_model_checkpoint, strict=False)
-            self._issue_warnings_after_load(load_result)
+            load_result = load_sharded_checkpoint(model, self.state.best_model_checkpoint, strict=strict_load)
+            if not strict_load:
+                self._issue_warnings_after_load(load_result)
         else:
             logger.warning(
                 f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
@@ -1891,17 +1905,21 @@ class Trainer:
                 xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
             reissue_pt_warnings(caught_warnings)
         elif is_sagemaker_mp_enabled():
-            if smp.rdp_rank() == 0:
-                # Consolidate the state dict on all processed of rdp_rank 0
-                opt_state_dict = self.optimizer.state_dict()
-                # Save it and the scheduler on the main process
-                if self.args.should_save:
-                    torch.save(opt_state_dict, os.path.join(output_dir, OPTIMIZER_NAME))
-                    with warnings.catch_warnings(record=True) as caught_warnings:
-                        torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
-                    reissue_pt_warnings(caught_warnings)
-                    if self.do_grad_scaling:
-                        torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
+            opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
+            smp.barrier()
+            if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
+                smp.save(
+                    opt_state_dict,
+                    os.path.join(output_dir, OPTIMIZER_NAME),
+                    partial=True,
+                    v3=smp.state.cfg.shard_optimizer_state,
+                )
+            if self.args.should_save:
+                with warnings.catch_warnings(record=True) as caught_warnings:
+                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
+                reissue_pt_warnings(caught_warnings)
+                if self.do_grad_scaling:
+                    torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
         elif self.args.should_save and not self.deepspeed:
             # deepspeed.save_checkpoint above saves model/optim/sched
             torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
@@ -1950,6 +1968,7 @@ class Trainer:
         # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
         # not yet exist.
         os.makedirs(output_dir, exist_ok=True)
+
         local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
         if local_rank == -1:
             torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
@@ -1972,9 +1991,12 @@ class Trainer:
             # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
             return
 
-        if os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME)) and os.path.isfile(
-            os.path.join(checkpoint, SCHEDULER_NAME)
-        ):
+        checkpoint_file_exists = (
+            glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
+            if is_sagemaker_mp_enabled()
+            else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
+        )
+        if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
             # Load in optimizer and scheduler states
             if is_torch_tpu_available():
                 # On TPU we have to take some extra precautions to properly load the states on the right device.
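
The resume path has to detect these partial files: smp.save(..., partial=True) writes one or more files derived from the target path with a suffix (an assumption based on the `_*` glob above), so a plain os.path.isfile check on OPTIMIZER_NAME would miss them. A hedged standalone version of the same check; the function name and keyword defaults are illustrative, not part of the PR.

import glob
import os


def optimizer_checkpoint_exists(checkpoint, optimizer_name="optimizer.pt", sagemaker_mp=False):
    # Partial SMP checkpoints live in files named after OPTIMIZER_NAME plus a
    # per-rank suffix (assumed from the `_*` glob), not in a single optimizer.pt.
    if sagemaker_mp:
        return len(glob.glob(os.path.join(checkpoint, optimizer_name) + "_*")) > 0
    return os.path.isfile(os.path.join(checkpoint, optimizer_name))
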
@@ -1990,9 +2012,16 @@ class Trainer:
                 self.lr_scheduler.load_state_dict(lr_scheduler_state)
             else:
                 map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
-                self.optimizer.load_state_dict(
-                    torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
-                )
+                if is_sagemaker_mp_enabled():
+
+                    def opt_load_hook(mod, opt):
+                        opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
+
+                    self.model_wrapped.register_post_step_hook(opt_load_hook)
+                else:
+                    self.optimizer.load_state_dict(
+                        torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
+                    )
                 with warnings.catch_warnings(record=True) as caught_warnings:
                     self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
                 reissue_pt_warnings(caught_warnings)
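
Loading mirrors the save side but is deferred: rather than calling load_state_dict eagerly, the diff registers a post-step hook on the wrapped model, presumably so the partial optimizer state is only applied once the model and optimizer have been partitioned. A minimal sketch of that pattern, using only the calls shown above (smp.load with partial=True, register_post_step_hook); the helper name is illustrative, not part of the PR.

import os

import smdistributed.modelparallel.torch as smp

OPTIMIZER_NAME = "optimizer.pt"  # the Trainer's file-name constant


def schedule_partial_optimizer_load(wrapped_model, checkpoint_dir):
    # Deferring the load via the hook means it runs only after a training step,
    # i.e. after partitioning has happened (assumption based on how the diff
    # replaces the eager torch.load under SageMaker Model Parallel).
    def opt_load_hook(mod, opt):
        opt.load_state_dict(smp.load(os.path.join(checkpoint_dir, OPTIMIZER_NAME), partial=True))

    wrapped_model.register_post_step_hook(opt_load_hook)
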
...