Unverified Commit 518dd127 authored by cavdard's avatar cavdard Committed by GitHub
Browse files

Updated checkpoint support for Sagemaker Model Parallel (#17219)



* adding partial checkpoint support for optimizer state

* formatted trainer.py

* Refactoring based on comments

* reformatting

* Update src/transformers/trainer.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/trainer.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Apply suggestions from code review
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/trainer.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarCavdar <dcavdar@a07817b12d7e.ant.amazon.com>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 71d18d08
......@@ -18,6 +18,7 @@ The Trainer class, to easily train a 🤗 Transformers from scratch or finetune
import contextlib
import functools
import glob
import inspect
import math
import os
......@@ -1305,7 +1306,7 @@ class Trainer:
if resume_from_checkpoint is None:
raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
if resume_from_checkpoint is not None:
if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled():
self._load_from_checkpoint(resume_from_checkpoint)
# If model was re-initialized, put it on the right device and update self.model_wrapped
......@@ -1406,6 +1407,9 @@ class Trainer:
model = self._wrap_model(self.model_wrapped)
if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
self._load_from_checkpoint(resume_from_checkpoint, model)
# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
self.model_wrapped = model
......@@ -1671,6 +1675,8 @@ class Trainer:
xm.rendezvous("load_best_model_at_end")
elif args.local_rank != -1:
dist.barrier()
elif is_sagemaker_mp_enabled():
smp.barrier()
self._load_best_model()
......@@ -1693,7 +1699,12 @@ class Trainer:
return TrainOutput(self.state.global_step, train_loss, metrics)
def _load_from_checkpoint(self, resume_from_checkpoint):
def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
if model is None:
model = self.model
strict_load = is_sagemaker_mp_enabled()
if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)) and not os.path.isfile(
os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
):
......@@ -1718,20 +1729,22 @@ class Trainer:
# We load the model state dict on the CPU to avoid an OOM error.
state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
# If the model is on the GPU, it still works!
load_result = self.model.load_state_dict(state_dict, strict=False)
self._issue_warnings_after_load(load_result)
load_result = model.load_state_dict(state_dict, strict=strict_load)
if not strict_load:
self._issue_warnings_after_load(load_result)
# release memory
del state_dict
else:
# We load the sharded checkpoint
load_result = load_sharded_checkpoint(self.model, resume_from_checkpoint, strict=False)
self._issue_warnings_after_load(load_result)
load_result = load_sharded_checkpoint(model, resume_from_checkpoint, strict=strict_load)
if not strict_load:
self._issue_warnings_after_load(load_result)
def _load_best_model(self):
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
strict_load = is_sagemaker_mp_enabled()
model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if os.path.exists(best_model_path):
if self.deepspeed:
# temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping
......@@ -1748,12 +1761,13 @@ class Trainer:
# We load the model state dict on the CPU to avoid an OOM error.
state_dict = torch.load(best_model_path, map_location="cpu")
# If the model is on the GPU, it still works!
load_result = self.model.load_state_dict(state_dict, strict=False)
self._issue_warnings_after_load(load_result)
load_result = model.load_state_dict(state_dict, strict=strict_load)
if not strict_load:
self._issue_warnings_after_load(load_result)
elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
# Best model is a sharded checkpoint
load_result = load_sharded_checkpoint(self.model, self.state.best_model_checkpoint, strict=False)
self._issue_warnings_after_load(load_result)
load_result = load_sharded_checkpoint(model, self.state.best_model_checkpoint, strict=strict_load)
if not strict_load:
self._issue_warnings_after_load(load_result)
else:
logger.warning(
f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
......@@ -1891,17 +1905,21 @@ class Trainer:
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
elif is_sagemaker_mp_enabled():
if smp.rdp_rank() == 0:
# Consolidate the state dict on all processed of rdp_rank 0
opt_state_dict = self.optimizer.state_dict()
# Save it and the scheduler on the main process
if self.args.should_save:
torch.save(opt_state_dict, os.path.join(output_dir, OPTIMIZER_NAME))
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
if self.do_grad_scaling:
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
smp.barrier()
if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
smp.save(
opt_state_dict,
os.path.join(output_dir, OPTIMIZER_NAME),
partial=True,
v3=smp.state.cfg.shard_optimizer_state,
)
if self.args.should_save:
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
if self.do_grad_scaling:
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
elif self.args.should_save and not self.deepspeed:
# deepspeed.save_checkpoint above saves model/optim/sched
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
......@@ -1950,6 +1968,7 @@ class Trainer:
# A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
# not yet exist.
os.makedirs(output_dir, exist_ok=True)
local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
if local_rank == -1:
torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
......@@ -1972,9 +1991,12 @@ class Trainer:
# deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
return
if os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME)) and os.path.isfile(
os.path.join(checkpoint, SCHEDULER_NAME)
):
checkpoint_file_exists = (
glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
if is_sagemaker_mp_enabled()
else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
)
if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
# Load in optimizer and scheduler states
if is_torch_tpu_available():
# On TPU we have to take some extra precautions to properly load the states on the right device.
......@@ -1990,9 +2012,16 @@ class Trainer:
self.lr_scheduler.load_state_dict(lr_scheduler_state)
else:
map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
self.optimizer.load_state_dict(
torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
)
if is_sagemaker_mp_enabled():
def opt_load_hook(mod, opt):
opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
self.model_wrapped.register_post_step_hook(opt_load_hook)
else:
self.optimizer.load_state_dict(
torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
)
with warnings.catch_warnings(record=True) as caught_warnings:
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
reissue_pt_warnings(caught_warnings)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment