Unverified commit d23cf5b1 authored by Victor Zhu, committed by GitHub

Add support for Sagemaker Model Parallel >= 1.10 new checkpoint API (#18221)

* Add support for Sagemaker Model Parallel >= 1.10 new checkpoint API

* Support loading checkpoints saved with SMP < 1.10 in SMP < 1.10 and SMP >= 1.10

* Support loading checkpoints saved with SMP >= 1.10 in SMP >= 1.10

* Fix bug and styling

* Update based on reviewer feedback
parent dbfeffd7
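
In outline, the patch hinges on two pieces: a module-level flag, `IS_SAGEMAKER_MP_POST_1_10`, computed from the installed `smdistributed` version, and an empty `user_content.pt` marker file written next to the model weights whenever a checkpoint is saved with SMP >= 1.10, so that loading code can tell which checkpoint API produced a given directory. A minimal standalone sketch of that detection logic (the try/except import guard stands in for the library's `is_sagemaker_mp_enabled()` check and is an assumption of this sketch, not part of the diff):

```python
import os

from packaging import version

# `smdistributed` only exists inside SageMaker training containers, so a
# plain import guard stands in for `is_sagemaker_mp_enabled()` here.
try:
    import smdistributed.modelparallel.torch as smp  # noqa: F401
    from smdistributed.modelparallel import __version__ as SMP_VERSION

    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
except ImportError:
    IS_SAGEMAKER_MP_POST_1_10 = False


def saved_with_new_smp_api(checkpoint_dir: str) -> bool:
    # smp >= 1.10 checkpoints are marked by an empty 'user_content.pt' file
    # written next to the weights at save time (see save_model in the diff).
    return os.path.isfile(os.path.join(checkpoint_dir, "user_content.pt"))
```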
src/transformers/modeling_utils.py
@@ -28,11 +28,13 @@ from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
+from packaging import version
 from torch import Tensor, device, nn
 from torch.nn import CrossEntropyLoss

 from requests import HTTPError
 from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
+from transformers.utils.import_utils import is_sagemaker_mp_enabled

 from .activations import get_activation
 from .configuration_utils import PretrainedConfig
@@ -88,6 +90,15 @@ logger = logging.get_logger(__name__)
 _init_weights = True

+if is_sagemaker_mp_enabled():
+    import smdistributed.modelparallel.torch as smp
+    from smdistributed.modelparallel import __version__ as SMP_VERSION
+
+    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
+else:
+    IS_SAGEMAKER_MP_POST_1_10 = False
+
+

 @contextmanager
 def no_init_weights(_enable=True):
     """
@@ -1520,6 +1531,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if state_dict is None:
             state_dict = model_to_save.state_dict()

+        # Translate state_dict from smp to hf if saving with smp >= 1.10
+        if IS_SAGEMAKER_MP_POST_1_10:
+            for smp_to_hf, _ in smp.state.module_manager.translate_functions:
+                state_dict = smp_to_hf(state_dict)
+
         # Handle the case where some state_dict keys shouldn't be saved
         if self._keys_to_ignore_on_save is not None:
             for ignore_key in self._keys_to_ignore_on_save:
...
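The save-path change above reflects that SMP >= 1.10 keeps parameters in its own layout, so `save_pretrained` first runs each registered smp-to-HF translation over the state dict before writing it to disk. A hedged sketch of that pipeline; the prefix-stripping translator pair below is purely hypothetical, and only the iteration pattern mirrors the diff:

```python
# Each entry in smp.state.module_manager.translate_functions is a pair whose
# first element converts an smp-layout state_dict toward HF layout; the
# second element (ignored here, as in the diff) is not needed when saving.
def translate_state_dict_for_saving(state_dict, translate_functions):
    for smp_to_hf, _ in translate_functions:
        state_dict = smp_to_hf(state_dict)
    return state_dict


# Illustrative usage with a hypothetical translator pair (not smp's API):
strip_prefix = lambda sd: {k.removeprefix("module."): v for k, v in sd.items()}
add_prefix = lambda sd: {f"module.{k}": v for k, v in sd.items()}
hf_state_dict = translate_state_dict_for_saving(
    {"module.weight": 1.0}, [(strip_prefix, add_prefix)]
)
assert "weight" in hf_state_dict
```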
src/transformers/trainer.py
@@ -192,8 +192,13 @@ if is_fairscale_available():
 if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp
+    from smdistributed.modelparallel import __version__ as SMP_VERSION
+
+    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")

     from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
+else:
+    IS_SAGEMAKER_MP_POST_1_10 = False


 if TYPE_CHECKING:
@@ -504,14 +509,23 @@ class Trainer:
             # BF16 + model parallelism in SageMaker: currently not supported, raise an error
             if args.bf16:
                 raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ")
-            # When there's mismatch between SMP config and trainer argument, use SMP config as truth
-            if args.fp16 != smp.state.cfg.fp16:
-                logger.warning(
-                    f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16},"
-                    f"but FP16 provided in trainer argument is {args.fp16},"
-                    f"setting to {smp.state.cfg.fp16}"
-                )
-                args.fp16 = smp.state.cfg.fp16
+
+            if IS_SAGEMAKER_MP_POST_1_10:
+                # When there's a mismatch between the SMP config and the trainer argument, use the SMP config as truth
+                if args.fp16 != smp.state.cfg.fp16:
+                    logger.warning(
+                        f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
+                        f"but FP16 provided in trainer argument is {args.fp16}, "
+                        f"setting to {smp.state.cfg.fp16}"
+                    )
+                    args.fp16 = smp.state.cfg.fp16
+            else:
+                # smp < 1.10 does not support fp16 in trainer.
+                if hasattr(smp.state.cfg, "fp16"):
+                    logger.warning(
+                        f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
+                        "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
+                    )

         if args.fp16 or args.bf16:
             if args.half_precision_backend == "auto":
@@ -991,10 +1005,12 @@ class Trainer:
             `create_scheduler`) in a subclass.
         """
         self.create_optimizer()
-        self.create_scheduler(
-            num_training_steps=num_training_steps,
-            optimizer=self.optimizer.optimizer if is_sagemaker_mp_enabled() and smp.state.cfg.fp16 else self.optimizer,
-        )
+        if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16:
+            # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer
+            optimizer = self.optimizer.optimizer
+        else:
+            optimizer = self.optimizer
+        self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)

     def create_optimizer(self):
         """
...@@ -1858,7 +1874,6 @@ class Trainer: ...@@ -1858,7 +1874,6 @@ class Trainer:
if model is None: if model is None:
model = self.model model = self.model
strict_load = is_sagemaker_mp_enabled()
if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)) and not os.path.isfile( if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)) and not os.path.isfile(
os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME) os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
@@ -1881,24 +1896,43 @@ class Trainer:
                 # will be resumed in deepspeed_init
                 pass
             elif os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
-                # We load the model state dict on the CPU to avoid an OOM error.
-                state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
-                # If the model is on the GPU, it still works!
-                load_result = model.load_state_dict(state_dict, strict=strict_load)
-                if not strict_load:
+                # If the model is on the GPU, it still works!
+                if is_sagemaker_mp_enabled():
+                    if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
+                        # If the 'user_content.pt' file exists, load with the new smp api.
+                        # Checkpoint must have been saved with the new smp api.
+                        smp.resume_from_checkpoint(
+                            path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False
+                        )
+                    else:
+                        # If the 'user_content.pt' file does NOT exist, load with the old smp api.
+                        # Checkpoint must have been saved with the old smp api.
+                        if hasattr(self.args, "fp16") and self.args.fp16 is True:
+                            logger.warning(
+                                "Enabling FP16 and loading from smp < 1.10 checkpoint together is not supported."
+                            )
+                        state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
+                        # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
+                        state_dict["_smp_is_partial"] = False
+                        load_result = model.load_state_dict(state_dict, strict=True)
+                        # release memory
+                        del state_dict
+                else:
+                    # We load the model state dict on the CPU to avoid an OOM error.
+                    state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
+                    # If the model is on the GPU, it still works!
+                    load_result = model.load_state_dict(state_dict)
+                    # release memory
+                    del state_dict
                     self._issue_warnings_after_load(load_result)
-                # release memory
-                del state_dict
             else:
                 # We load the sharded checkpoint
-                load_result = load_sharded_checkpoint(model, resume_from_checkpoint, strict=strict_load)
-                if not strict_load:
+                load_result = load_sharded_checkpoint(model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled())
+                if not is_sagemaker_mp_enabled():
                     self._issue_warnings_after_load(load_result)

     def _load_best_model(self):
         logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
         best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
-        strict_load = is_sagemaker_mp_enabled()
         model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
         if os.path.exists(best_model_path):
             if self.deepspeed:
@@ -1920,15 +1954,34 @@ class Trainer:
                 self.optimizer = optimizer
                 self.lr_scheduler = lr_scheduler
             else:
-                # We load the model state dict on the CPU to avoid an OOM error.
-                state_dict = torch.load(best_model_path, map_location="cpu")
-                # If the model is on the GPU, it still works!
-                load_result = model.load_state_dict(state_dict, strict=strict_load)
-                if not strict_load:
+                if is_sagemaker_mp_enabled():
+                    if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
+                        # If the 'user_content.pt' file exists, load with the new smp api.
+                        # Checkpoint must have been saved with the new smp api.
+                        smp.resume_from_checkpoint(
+                            path=self.state.best_model_checkpoint,
+                            tag=WEIGHTS_NAME,
+                            partial=False,
+                            load_optimizer=False,
+                        )
+                    else:
+                        # If the 'user_content.pt' file does NOT exist, load with the old smp api.
+                        # Checkpoint must have been saved with the old smp api.
+                        state_dict = torch.load(best_model_path, map_location="cpu")
+                        state_dict["_smp_is_partial"] = False
+                        load_result = model.load_state_dict(state_dict, strict=True)
+                else:
+                    # We load the model state dict on the CPU to avoid an OOM error.
+                    state_dict = torch.load(best_model_path, map_location="cpu")
+                    # If the model is on the GPU, it still works!
+                    load_result = model.load_state_dict(state_dict)
+                if not is_sagemaker_mp_enabled():
                     self._issue_warnings_after_load(load_result)
         elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
-            load_result = load_sharded_checkpoint(model, self.state.best_model_checkpoint, strict=strict_load)
-            if not strict_load:
+            load_result = load_sharded_checkpoint(
+                model, self.state.best_model_checkpoint, strict=is_sagemaker_mp_enabled()
+            )
+            if not is_sagemaker_mp_enabled():
                 self._issue_warnings_after_load(load_result)
         else:
             logger.warning(
@@ -2174,11 +2227,20 @@ class Trainer:
         else:
             map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
             if is_sagemaker_mp_enabled():
+                if os.path.isfile(os.path.join(checkpoint, "user_content.pt")):
+                    # Optimizer checkpoint was saved with smp >= 1.10
+                    def opt_load_hook(mod, opt):
+                        opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))

-                def opt_load_hook(mod, opt):
-                    opt.load_state_dict(
-                        smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True), gather_if_shard=False
-                    )
+                else:
+                    # Optimizer checkpoint was saved with smp < 1.10
+                    def opt_load_hook(mod, opt):
+                        if IS_SAGEMAKER_MP_POST_1_10:
+                            opt.load_state_dict(
+                                smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True)
+                            )
+                        else:
+                            opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))

                 self.model_wrapped.register_post_step_hook(opt_load_hook)
             else:
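The optimizer branch above distinguishes three cases: a new-format checkpoint loads directly, an old-format checkpoint on SMP >= 1.10 needs `back_compat=True`, and an old-format checkpoint on SMP < 1.10 loads as before. A hedged recap as a standalone function; all three `smp.load()` variants appear in the diff, while the function itself and its parameters are illustrative:

```python
# 'optimizer_path', 'has_marker', and 'post_1_10' are illustrative inputs:
# has_marker means a 'user_content.pt' file sits next to the checkpoint, and
# post_1_10 is the IS_SAGEMAKER_MP_POST_1_10 runtime flag.
def load_optimizer_state(opt, optimizer_path, has_marker, smp, post_1_10):
    if has_marker:
        # Saved with smp >= 1.10: load directly with the current API.
        opt.load_state_dict(smp.load(optimizer_path, partial=True))
    elif post_1_10:
        # Old checkpoint on a new runtime: ask smp for backward compatibility.
        opt.load_state_dict(smp.load(optimizer_path, partial=True, back_compat=True))
    else:
        # Old checkpoint on an old runtime: plain partial load.
        opt.load_state_dict(smp.load(optimizer_path, partial=True))
```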
@@ -2479,9 +2541,13 @@ class Trainer:
             self._save_tpu(output_dir)
         elif is_sagemaker_mp_enabled():
             # Calling the state_dict needs to be done on the wrapped model and on all processes.
+            os.makedirs(output_dir, exist_ok=True)
             state_dict = self.model_wrapped.state_dict()
             if self.args.should_save:
                 self._save(output_dir, state_dict=state_dict)
+            if IS_SAGEMAKER_MP_POST_1_10:
+                # 'user_content.pt' indicates model state_dict saved with smp >= 1.10
+                Path(os.path.join(output_dir, "user_content.pt")).touch()
         elif (
             ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
             or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
...
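Taken together, the model-weight loading rules in this commit reduce to one dispatch on the marker file. A hedged recap, not Trainer API: `model`, `checkpoint_dir`, and `weights_name` are illustrative, and `smp` is only available inside a SageMaker model-parallel job (pass `None` elsewhere):

```python
import os

import torch


def load_model_weights(model, checkpoint_dir, weights_name="pytorch_model.bin", smp=None):
    """Recap of the checkpoint-format dispatch added in this commit."""
    weights_path = os.path.join(checkpoint_dir, weights_name)
    if smp is not None and os.path.isfile(os.path.join(checkpoint_dir, "user_content.pt")):
        # Marker present: saved with smp >= 1.10, so restore through the new API.
        smp.resume_from_checkpoint(path=checkpoint_dir, tag=weights_name, partial=False, load_optimizer=False)
        return
    state_dict = torch.load(weights_path, map_location="cpu")  # CPU load avoids GPU OOM
    if smp is not None:
        # No marker: saved with smp < 1.10. Flag the dict as already-complete so
        # smp does not try to auto-translate it from HF layout again.
        state_dict["_smp_is_partial"] = False
    model.load_state_dict(state_dict, strict=True)
    del state_dict  # release memory
```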