Unverified commit d23cf5b1, authored by Victor Zhu, committed by GitHub
Browse files

Add support for Sagemaker Model Parallel >= 1.10 new checkpoint API (#18221)

* Add support for Sagemaker Model Parallel >= 1.10 new checkpoint API

* Support loading checkpoints saved with SMP < 1.10 in SMP < 1.10 and SMP >= 1.10

* Support loading checkpoints saved with SMP >= 1.10 in SMP >= 1.10

* Fix bug and styling

* Update based on reviewer feedback
parent dbfeffd7
......@@ -28,11 +28,13 @@ from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from packaging import version
from torch import Tensor, device, nn
from torch.nn import CrossEntropyLoss
from requests import HTTPError
from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
from transformers.utils.import_utils import is_sagemaker_mp_enabled
from .activations import get_activation
from .configuration_utils import PretrainedConfig
......@@ -88,6 +90,15 @@ logger = logging.get_logger(__name__)
_init_weights = True
# Module-level flag: True only when SageMaker Model Parallel (SMP) is enabled and the
# installed smdistributed.modelparallel version is 1.10 or newer. It defaults to False
# so code that reads the flag never has to check whether SMP is available first.
IS_SAGEMAKER_MP_POST_1_10 = False
if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp
    from smdistributed.modelparallel import __version__ as SMP_VERSION

    # Compare parsed versions rather than raw strings so that e.g. "1.9" < "1.10".
    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
@contextmanager
def no_init_weights(_enable=True):
"""
......@@ -1520,6 +1531,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if state_dict is None:
state_dict = model_to_save.state_dict()
# Translate state_dict from smp to hf if saving with smp >= 1.10
if IS_SAGEMAKER_MP_POST_1_10:
for smp_to_hf, _ in smp.state.module_manager.translate_functions:
state_dict = smp_to_hf(state_dict)
# Handle the case where some state_dict keys shouldn't be saved
if self._keys_to_ignore_on_save is not None:
for ignore_key in self._keys_to_ignore_on_save:
......
......@@ -192,8 +192,13 @@ if is_fairscale_available():
# Decide once, at import time, whether the SageMaker Model Parallel (SMP) library is
# enabled and at version 1.10 or newer; the SMP helpers are imported only in that
# environment.
if not is_sagemaker_mp_enabled():
    # No SMP: the flag is still defined so later code can test it unconditionally.
    IS_SAGEMAKER_MP_POST_1_10 = False
else:
    import smdistributed.modelparallel.torch as smp
    from smdistributed.modelparallel import __version__ as SMP_VERSION

    # Parsed (not lexicographic) comparison against the 1.10 threshold.
    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")

    from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
if TYPE_CHECKING:
......@@ -504,6 +509,8 @@ class Trainer:
# BF16 + model parallelism in SageMaker: currently not supported, raise an error
if args.bf16:
raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ")
if IS_SAGEMAKER_MP_POST_1_10:
# When there's mismatch between SMP config and trainer argument, use SMP config as truth
if args.fp16 != smp.state.cfg.fp16:
logger.warning(
......@@ -512,6 +519,13 @@ class Trainer:
f"setting to {smp.state.cfg.fp16}"
)
args.fp16 = smp.state.cfg.fp16
else:
# smp < 1.10 does not support fp16 in trainer.
if hasattr(smp.state.cfg, "fp16"):
logger.warning(
f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
"but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
)
if args.fp16 or args.bf16:
if args.half_precision_backend == "auto":
......@@ -991,10 +1005,12 @@ class Trainer:
`create_scheduler`) in a subclass.
"""
self.create_optimizer()
self.create_scheduler(
num_training_steps=num_training_steps,
optimizer=self.optimizer.optimizer if is_sagemaker_mp_enabled() and smp.state.cfg.fp16 else self.optimizer,
)
if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16:
# If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer
optimizer = self.optimizer.optimizer
else:
optimizer = self.optimizer
self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
def create_optimizer(self):
"""
......@@ -1858,7 +1874,6 @@ class Trainer:
if model is None:
model = self.model
strict_load = is_sagemaker_mp_enabled()
if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)) and not os.path.isfile(
os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
......@@ -1881,24 +1896,43 @@ class Trainer:
# will be resumed in deepspeed_init
pass
elif os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
# If the model is on the GPU, it still works!
if is_sagemaker_mp_enabled():
if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
# If the 'user_content.pt' file exists, load with the new smp api.
# Checkpoint must have been saved with the new smp api.
smp.resume_from_checkpoint(
path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False
)
else:
# If the 'user_content.pt' file does NOT exist, load with the old smp api.
# Checkpoint must have been saved with the old smp api.
if hasattr(self.args, "fp16") and self.args.fp16 is True:
logger.warning(
"Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported."
)
state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
# Required for smp to not auto-translate state_dict from hf to smp (is already smp).
state_dict["_smp_is_partial"] = False
load_result = model.load_state_dict(state_dict, strict=True)
# release memory
del state_dict
else:
# We load the model state dict on the CPU to avoid an OOM error.
state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
# If the model is on the GPU, it still works!
load_result = model.load_state_dict(state_dict, strict=strict_load)
if not strict_load:
self._issue_warnings_after_load(load_result)
load_result = model.load_state_dict(state_dict)
# release memory
del state_dict
self._issue_warnings_after_load(load_result)
else:
# We load the sharded checkpoint
load_result = load_sharded_checkpoint(model, resume_from_checkpoint, strict=strict_load)
if not strict_load:
load_result = load_sharded_checkpoint(model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled())
if not is_sagemaker_mp_enabled():
self._issue_warnings_after_load(load_result)
def _load_best_model(self):
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
strict_load = is_sagemaker_mp_enabled()
model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if os.path.exists(best_model_path):
if self.deepspeed:
......@@ -1919,16 +1953,35 @@ class Trainer:
self.deepspeed = deepspeed_engine
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
else:
if is_sagemaker_mp_enabled():
if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
# If the 'user_content.pt' file exists, load with the new smp api.
# Checkpoint must have been saved with the new smp api.
smp.resume_from_checkpoint(
path=self.state.best_model_checkpoint,
tag=WEIGHTS_NAME,
partial=False,
load_optimizer=False,
)
else:
# If the 'user_content.pt' file does NOT exist, load with the old smp api.
# Checkpoint must have been saved with the old smp api.
state_dict = torch.load(best_model_path, map_location="cpu")
state_dict["_smp_is_partial"] = False
load_result = model.load_state_dict(state_dict, strict=True)
else:
# We load the model state dict on the CPU to avoid an OOM error.
state_dict = torch.load(best_model_path, map_location="cpu")
# If the model is on the GPU, it still works!
load_result = model.load_state_dict(state_dict, strict=strict_load)
if not strict_load:
load_result = model.load_state_dict(state_dict)
if not is_sagemaker_mp_enabled():
self._issue_warnings_after_load(load_result)
elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
load_result = load_sharded_checkpoint(model, self.state.best_model_checkpoint, strict=strict_load)
if not strict_load:
load_result = load_sharded_checkpoint(
model, self.state.best_model_checkpoint, strict=is_sagemaker_mp_enabled()
)
if not is_sagemaker_mp_enabled():
self._issue_warnings_after_load(load_result)
else:
logger.warning(
......@@ -2174,11 +2227,20 @@ class Trainer:
else:
map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
if is_sagemaker_mp_enabled():
if os.path.isfile(os.path.join(checkpoint, "user_content.pt")):
# Optimizer checkpoint was saved with smp >= 1.10
def opt_load_hook(mod, opt):
opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
else:
# Optimizer checkpoint was saved with smp < 1.10
def opt_load_hook(mod, opt):
if IS_SAGEMAKER_MP_POST_1_10:
opt.load_state_dict(
smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True), gather_if_shard=False
smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True)
)
else:
opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
self.model_wrapped.register_post_step_hook(opt_load_hook)
else:
......@@ -2479,9 +2541,13 @@ class Trainer:
self._save_tpu(output_dir)
elif is_sagemaker_mp_enabled():
# Calling the state_dict needs to be done on the wrapped model and on all processes.
os.makedirs(output_dir, exist_ok=True)
state_dict = self.model_wrapped.state_dict()
if self.args.should_save:
self._save(output_dir, state_dict=state_dict)
if IS_SAGEMAKER_MP_POST_1_10:
# 'user_content.pt' indicates model state_dict saved with smp >= 1.10
Path(os.path.join(output_dir, "user_content.pt")).touch()
elif (
ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment