Unverified commit f032e56f authored by Jeff Rasley, committed by GitHub

Validate consistent ckpt tags across ranks (#667)

parent 981bc7d4
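This change makes DeepSpeed verify at save time that every rank passed the same checkpoint tag to save_checkpoint, since rank-specific tags can break restores under a different world size. The behavior is controlled by a new "checkpoint" block in the DeepSpeed config, parsed by DeepSpeedConfig and exercised by the unit tests at the end of this diff. A minimal sketch of such a config (only the "checkpoint" block is introduced by this commit; the other keys are simply the ones the tests use):

{
    "train_batch_size": 2,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 0.00015
        }
    },
    "checkpoint": {
        "tag_validation": "WARN"
    }
}

tag_validation accepts "IGNORE", "WARN" (the default), or "FAIL": IGNORE skips the check, WARN logs a warning when tags differ across ranks, FAIL raises an assertion, and any other value raises DeepSpeedConfigError during deepspeed.initialize.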
@@ -10,7 +10,7 @@ from .runtime.engine import DeepSpeedEngine
 from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
 from .runtime.pipe.engine import PipelineEngine
 from .runtime.lr_schedules import add_tuning_arguments
-from .runtime.config import DeepSpeedConfig
+from .runtime.config import DeepSpeedConfig, DeepSpeedConfigError
 from .runtime.activation_checkpointing import checkpointing
 from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
 from .utils import log_dist
...
@@ -40,6 +40,10 @@ TORCH_ADAM_PARAM = "torch_adam"
 ADAM_W_MODE_PARAM = "adam_w_mode"
+
+class DeepSpeedConfigError(Exception):
+    pass
+
 def get_pld_enabled(param_dict):
     if PROGRESSIVE_LAYER_DROP in param_dict.keys():
         return get_scalar_param(param_dict[PROGRESSIVE_LAYER_DROP],
@@ -471,6 +475,21 @@ def get_tensorboard_job_name(param_dict):
     return TENSORBOARD_JOB_NAME_DEFAULT
+
+
+def get_checkpoint_params(param_dict):
+    return param_dict.get(CHECKPOINT, {})
+
+
+def get_checkpoint_tag_validation_mode(checkpoint_params):
+    tag_validation_mode = checkpoint_params.get(CHECKPOINT_TAG_VALIDATION,
+                                                CHECKPOINT_TAG_VALIDATION_DEFAULT)
+    tag_validation_mode = tag_validation_mode.upper()
+    if tag_validation_mode in CHECKPOINT_TAG_VALIDATION_MODES:
+        return tag_validation_mode
+    else:
+        raise DeepSpeedConfigError("Checkpoint config contains invalid tag_validation " \
+            f"value of {tag_validation_mode}, expecting one of {CHECKPOINT_TAG_VALIDATION_MODES}")

 '''Write deepspeed config files by modifying basic templates.
 Can be used for quicly changing parameters via command line parameters.'''
@@ -627,6 +646,11 @@ class DeepSpeedConfig(object):
         self.pld_enabled = get_pld_enabled(param_dict)
         self.pld_params = get_pld_params(param_dict)
+
+        checkpoint_params = get_checkpoint_params(param_dict)
+        validation_mode = get_checkpoint_tag_validation_mode(checkpoint_params)
+        self.checkpoint_tag_validation_enabled = validation_mode != ValidationMode.IGNORE
+        self.checkpoint_tag_validation_fail = validation_mode == ValidationMode.FAIL

     def _batch_assertion(self):
         train_batch = self.train_batch_size
...
@@ -287,7 +287,9 @@ TENSORBOARD_OUTPUT_PATH_DEFAULT = ""
 TENSORBOARD_JOB_NAME = "job_name"
 TENSORBOARD_JOB_NAME_DEFAULT = "DeepSpeedJobName"

+#########################################
 # Progressive Layer Drop (PLD)
+#########################################
 PROGRESSIVE_LAYER_DROP = "progressive_layer_drop"

 # PLD enable signal
@@ -299,3 +301,26 @@ PLD_THETA_DEFAULT = 1.0
 PLD_GAMMA = "gamma"
 PLD_GAMMA_DEFAULT = 0.001
+
+
+#########################################
+# Validation modes
+#########################################
+class ValidationMode:
+    WARN = "WARN"
+    IGNORE = "IGNORE"
+    FAIL = "FAIL"
+
+
+#########################################
+# Checkpoint config params
+#########################################
+# "checkpoint": {tag_validation=["Ignore"|"Warn"|"Fail"]}
+CHECKPOINT = "checkpoint"
+CHECKPOINT_TAG_VALIDATION = "tag_validation"
+CHECKPOINT_TAG_VALIDATION_DEFAULT = ValidationMode.WARN
+CHECKPOINT_TAG_VALIDATION_MODES = [
+    ValidationMode.WARN,
+    ValidationMode.IGNORE,
+    ValidationMode.FAIL
+]
@@ -5,6 +5,7 @@ Copyright 2019 The Microsoft DeepSpeed Team
 import os
 import torch
 import warnings
+import hashlib
 import torch.distributed as dist
 from torch.nn.modules import Module
@@ -213,6 +214,12 @@ class DeepSpeedEngine(Module):
         """
         return self.train_batch_size, self.train_micro_batch_size_per_gpu, self.gradient_accumulation_steps
+
+    def checkpoint_tag_validation_enabled(self):
+        return self._config.checkpoint_tag_validation_enabled
+
+    def checkpoint_tag_validation_fail(self):
+        return self._config.checkpoint_tag_validation_fail

     def elasticity_enabled(self):
         return self._config.elasticity_enabled
@@ -1435,12 +1442,30 @@ class DeepSpeedEngine(Module):
             )
         return zero_optimizer_sd
+
+    def _checkpoint_tag_validation(self, tag):
+        if self.checkpoint_tag_validation_enabled():
+            s_hash = hashlib.sha1(tag.encode())
+            bhash = torch.ByteTensor([s_hash.digest()]).flatten().to(self.device)
+            max_bhash = bhash.clone()
+            min_bhash = bhash.clone()
+            dist.all_reduce(max_bhash, op=torch.distributed.ReduceOp.MAX)
+            dist.all_reduce(min_bhash, op=torch.distributed.ReduceOp.MIN)
+            valid = all(min_bhash == bhash) and all(max_bhash == bhash)
+            msg = f"[rank={dist.get_rank()}] The checkpoint tag name '{tag}' is not consistent across " \
+                "all ranks. Including rank unique information in checkpoint tag could cause issues when " \
+                "restoring with different world sizes."
+            if self.checkpoint_tag_validation_fail():
+                assert valid, msg
+            elif not valid:
+                logger.warning(msg)

     def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True):
         r"""Save training checkpoint
         Arguments:
             save_dir: Required. Directory for saving the checkpoint
-            tag: Optional. Checkpoint tag used as a unique identifier for the checkpoint, global step is used if not provided.
+            tag: Optional. Checkpoint tag used as a unique identifier for the checkpoint, global step is
+                used if not provided. Tag name must be the same across all ranks.
             client_state: Optional. State dictionary used for saving required training states in the client code.
             save_latest: Optional. Save a file 'latest' pointing to the latest saved checkpoint.
         """
@@ -1454,6 +1479,9 @@ class DeepSpeedEngine(Module):
         if tag is None:
             tag = f"global_step{self.global_steps}"
+
+        # Ensure checkpoint tag is consistent across ranks
+        self._checkpoint_tag_validation(tag)
         if self.save_non_zero_checkpoint:
             self._create_checkpoint_file(save_dir, tag, False)
             self._save_checkpoint(save_dir, tag, client_state=client_state)
...
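For reference, the consistency check in _checkpoint_tag_validation above reduces to: hash the tag on each rank, all-reduce the hash bytes with MAX and then with MIN, and compare both results against the local hash; they can only match on every rank if every rank hashed the same tag. A standalone sketch of the same idea outside the engine (the helper name check_tag_consistent is made up here, and it assumes torch.distributed has already been initialized):

import hashlib

import torch
import torch.distributed as dist


def check_tag_consistent(tag: str, device: torch.device) -> bool:
    # Fixed-size fingerprint of the tag, so ranks never exchange
    # variable-length strings.
    digest = hashlib.sha1(tag.encode()).digest()
    local = torch.tensor(list(digest), dtype=torch.uint8, device=device)

    # If all ranks hold identical bytes, the element-wise MAX and MIN
    # reductions both equal the local hash; any divergence shows up as a
    # mismatch in at least one byte.
    max_hash = local.clone()
    min_hash = local.clone()
    dist.all_reduce(max_hash, op=dist.ReduceOp.MAX)
    dist.all_reduce(min_hash, op=dist.ReduceOp.MIN)
    return bool((max_hash == local).all()) and bool((min_hash == local).all())

Reducing a fixed 20-byte SHA-1 digest instead of gathering the raw tag strings keeps the collective payload constant regardless of tag length or world size.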
@@ -761,3 +761,68 @@ def test_checkpoint_missing_latest(tmpdir):
     model.load_checkpoint(tmpdir)
     _helper(args=args, model=model, hidden_dim=hidden_dim)
+
+
+@pytest.mark.parametrize('valid_mode', ["FAIL", "WARN", "IGNORE"])
+def test_checkpoint_unique_tag(tmpdir, valid_mode):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "checkpoint": {
+            "tag_validation": valid_mode
+        }
+    }
+    hidden_dim = 10
+    args = args_from_dict(tmpdir, config_dict)
+    model = SimpleModel(hidden_dim, rank=args.local_rank)
+
+    @distributed_test(world_size=[2])
+    def _helper(args, model, hidden_dim):
+        model, _, _, _ = deepspeed.initialize(args=args,
+                                              model=model,
+                                              model_parameters=model.parameters())
+        if valid_mode == "FAIL":
+            with pytest.raises(AssertionError):
+                model.save_checkpoint(save_dir=tmpdir,
+                                      tag=f"tag-{torch.distributed.get_rank()}")
+        else:
+            model.save_checkpoint(save_dir=tmpdir,
+                                  tag=f"tag-{torch.distributed.get_rank()}")
+
+    _helper(args=args, model=model, hidden_dim=hidden_dim)
+
+
+def test_checkpoint_unknown_tag_validation(tmpdir):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "checkpoint": {
+            "tag_validation": "foo"
+        }
+    }
+    hidden_dim = 10
+    args = args_from_dict(tmpdir, config_dict)
+    model = SimpleModel(hidden_dim, rank=args.local_rank)
+
+    @distributed_test(world_size=[1])
+    def _helper(args, model, hidden_dim):
+        with pytest.raises(deepspeed.DeepSpeedConfigError):
+            model, _, _, _ = deepspeed.initialize(args=args,
+                                                  model=model,
+                                                  model_parameters=model.parameters())
+
+    _helper(args=args, model=model, hidden_dim=hidden_dim)