"examples/community/imagic_stable_diffusion.py" did not exist on "b671cb092082bab9855d028802fccf703db03f16"
Unverified commit f032e56f, authored by Jeff Rasley, committed by GitHub

Validate consistent ckpt tags across ranks (#667)

parent 981bc7d4
deepspeed/__init__.py
@@ -10,7 +10,7 @@ from .runtime.engine import DeepSpeedEngine
 from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
 from .runtime.pipe.engine import PipelineEngine
 from .runtime.lr_schedules import add_tuning_arguments
-from .runtime.config import DeepSpeedConfig
+from .runtime.config import DeepSpeedConfig, DeepSpeedConfigError
 from .runtime.activation_checkpointing import checkpointing
 from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
 from .utils import log_dist
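With this re-export, callers can catch configuration problems from the top-level package. A minimal sketch of the intended usage, assuming `args` and `model` are set up as in the tests further down:

import deepspeed

try:
    engine, _, _, _ = deepspeed.initialize(args=args,
                                           model=model,
                                           model_parameters=model.parameters())
except deepspeed.DeepSpeedConfigError as err:
    # Raised while the config is parsed, e.g. an unknown tag_validation mode.
    print(f"Invalid DeepSpeed config: {err}")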
deepspeed/runtime/config.py
@@ -40,6 +40,10 @@ TORCH_ADAM_PARAM = "torch_adam"
 ADAM_W_MODE_PARAM = "adam_w_mode"
 
 
+class DeepSpeedConfigError(Exception):
+    pass
+
+
 def get_pld_enabled(param_dict):
     if PROGRESSIVE_LAYER_DROP in param_dict.keys():
         return get_scalar_param(param_dict[PROGRESSIVE_LAYER_DROP],
@@ -471,6 +475,21 @@ def get_tensorboard_job_name(param_dict):
     return TENSORBOARD_JOB_NAME_DEFAULT
 
 
+def get_checkpoint_params(param_dict):
+    return param_dict.get(CHECKPOINT, {})
+
+
+def get_checkpoint_tag_validation_mode(checkpoint_params):
+    tag_validation_mode = checkpoint_params.get(CHECKPOINT_TAG_VALIDATION,
+                                                CHECKPOINT_TAG_VALIDATION_DEFAULT)
+    tag_validation_mode = tag_validation_mode.upper()
+    if tag_validation_mode in CHECKPOINT_TAG_VALIDATION_MODES:
+        return tag_validation_mode
+    else:
+        raise DeepSpeedConfigError("Checkpoint config contains invalid tag_validation "
+                                   f"value of {tag_validation_mode}, expecting one of {CHECKPOINT_TAG_VALIDATION_MODES}")
+
+
 '''Write deepspeed config files by modifying basic templates.
 Can be used for quickly changing parameters via command line parameters.'''
@@ -627,6 +646,11 @@ class DeepSpeedConfig(object):
         self.pld_enabled = get_pld_enabled(param_dict)
         self.pld_params = get_pld_params(param_dict)
 
+        checkpoint_params = get_checkpoint_params(param_dict)
+        validation_mode = get_checkpoint_tag_validation_mode(checkpoint_params)
+        self.checkpoint_tag_validation_enabled = validation_mode != ValidationMode.IGNORE
+        self.checkpoint_tag_validation_fail = validation_mode == ValidationMode.FAIL
+
     def _batch_assertion(self):
         train_batch = self.train_batch_size
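The three modes collapse into two booleans on DeepSpeedConfig: IGNORE disables the check, WARN enables it without failing, and FAIL enables both. A small sketch of how the parser above behaves (expected results shown as comments; this mirrors the code in this diff rather than a separate API):

from deepspeed.runtime.config import get_checkpoint_tag_validation_mode

get_checkpoint_tag_validation_mode({"tag_validation": "fail"})  # -> "FAIL" (case-insensitive)
get_checkpoint_tag_validation_mode({})                          # -> "WARN" (the default)
get_checkpoint_tag_validation_mode({"tag_validation": "foo"})   # raises DeepSpeedConfigError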
deepspeed/runtime/constants.py
@@ -287,7 +287,9 @@ TENSORBOARD_OUTPUT_PATH_DEFAULT = ""
 TENSORBOARD_JOB_NAME = "job_name"
 TENSORBOARD_JOB_NAME_DEFAULT = "DeepSpeedJobName"
+
+
 #########################################
 # Progressive Layer Drop (PLD)
 #########################################
 PROGRESSIVE_LAYER_DROP = "progressive_layer_drop"
 
 # PLD enable signal
@@ -299,3 +301,26 @@ PLD_THETA_DEFAULT = 1.0
 PLD_GAMMA = "gamma"
 PLD_GAMMA_DEFAULT = 0.001
+
+
+#########################################
+# Validation modes
+#########################################
+class ValidationMode:
+    WARN = "WARN"
+    IGNORE = "IGNORE"
+    FAIL = "FAIL"
+
+
+#########################################
+# Checkpoint config params
+#########################################
+# "checkpoint": {tag_validation=["Ignore"|"Warn"|"Fail"]}
+CHECKPOINT = "checkpoint"
+CHECKPOINT_TAG_VALIDATION = "tag_validation"
+CHECKPOINT_TAG_VALIDATION_DEFAULT = ValidationMode.WARN
+CHECKPOINT_TAG_VALIDATION_MODES = [
+    ValidationMode.WARN,
+    ValidationMode.IGNORE,
+    ValidationMode.FAIL
+]
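In a user-facing config these constants map onto a new top-level "checkpoint" section. A minimal sketch of a config dict that opts into strict validation (the other keys are ordinary DeepSpeed settings, nothing new to this change):

config_dict = {
    "train_batch_size": 2,
    "optimizer": {
        "type": "Adam",
        "params": {"lr": 0.00015}
    },
    # New in this change: fail hard if ranks disagree on a checkpoint tag.
    "checkpoint": {"tag_validation": "FAIL"}
}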
deepspeed/runtime/engine.py
@@ -5,6 +5,7 @@ Copyright 2019 The Microsoft DeepSpeed Team
 import os
 import torch
 import warnings
+import hashlib
 import torch.distributed as dist
 from torch.nn.modules import Module
@@ -213,6 +214,12 @@ class DeepSpeedEngine(Module):
         """
         return self.train_batch_size, self.train_micro_batch_size_per_gpu, self.gradient_accumulation_steps
 
+    def checkpoint_tag_validation_enabled(self):
+        return self._config.checkpoint_tag_validation_enabled
+
+    def checkpoint_tag_validation_fail(self):
+        return self._config.checkpoint_tag_validation_fail
+
     def elasticity_enabled(self):
         return self._config.elasticity_enabled
@@ -1435,12 +1442,30 @@ class DeepSpeedEngine(Module):
             )
 
         return zero_optimizer_sd
 
+    def _checkpoint_tag_validation(self, tag):
+        if self.checkpoint_tag_validation_enabled():
+            s_hash = hashlib.sha1(tag.encode())
+            bhash = torch.ByteTensor([s_hash.digest()]).flatten().to(self.device)
+            max_bhash = bhash.clone()
+            min_bhash = bhash.clone()
+            dist.all_reduce(max_bhash, op=torch.distributed.ReduceOp.MAX)
+            dist.all_reduce(min_bhash, op=torch.distributed.ReduceOp.MIN)
+            valid = all(min_bhash == bhash) and all(max_bhash == bhash)
+            msg = f"[rank={dist.get_rank()}] The checkpoint tag name '{tag}' is not consistent across " \
+                "all ranks. Including rank unique information in checkpoint tag could cause issues when " \
+                "restoring with different world sizes."
+            if self.checkpoint_tag_validation_fail():
+                assert valid, msg
+            elif not valid:
+                logger.warning(msg)
+
     def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True):
         r"""Save training checkpoint
 
         Arguments:
             save_dir: Required. Directory for saving the checkpoint
-            tag: Optional. Checkpoint tag used as a unique identifier for the checkpoint, global step is used if not provided.
+            tag: Optional. Checkpoint tag used as a unique identifier for the checkpoint, global step is
+                used if not provided. Tag name must be the same across all ranks.
             client_state: Optional. State dictionary used for saving required training states in the client code.
             save_latest: Optional. Save a file 'latest' pointing to the latest saved checkpoint.
         """
@@ -1454,6 +1479,9 @@ class DeepSpeedEngine(Module):
         if tag is None:
             tag = f"global_step{self.global_steps}"
 
+        # Ensure checkpoint tag is consistent across ranks
+        self._checkpoint_tag_validation(tag)
+
         if self.save_non_zero_checkpoint:
             self._create_checkpoint_file(save_dir, tag, False)
             self._save_checkpoint(save_dir, tag, client_state=client_state)
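The validation avoids gathering variable-length strings across ranks: each rank hashes its tag into a fixed 20-byte SHA-1 digest, all-reduces the element-wise MAX and MIN of that digest, and the tags can only be identical everywhere if both reductions equal the local digest on every rank. A standalone sketch of the same trick, assuming torch.distributed is already initialized:

import hashlib

import torch
import torch.distributed as dist

def tags_match_across_ranks(tag: str, device) -> bool:
    # Fixed-size digest, so ranks can all_reduce it regardless of tag length.
    digest = hashlib.sha1(tag.encode()).digest()
    bhash = torch.ByteTensor(list(digest)).to(device)
    max_bhash = bhash.clone()
    min_bhash = bhash.clone()
    dist.all_reduce(max_bhash, op=dist.ReduceOp.MAX)
    dist.all_reduce(min_bhash, op=dist.ReduceOp.MIN)
    # Any disagreement on the tag makes MIN or MAX diverge from the local hash.
    return bool((min_bhash == bhash).all()) and bool((max_bhash == bhash).all())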
tests/unit/test_checkpointing.py
@@ -761,3 +761,68 @@ def test_checkpoint_missing_latest(tmpdir):
         model.load_checkpoint(tmpdir)
 
     _helper(args=args, model=model, hidden_dim=hidden_dim)
+
+
+@pytest.mark.parametrize('valid_mode', ["FAIL", "WARN", "IGNORE"])
+def test_checkpoint_unique_tag(tmpdir, valid_mode):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "checkpoint": {
+            "tag_validation": valid_mode
+        }
+    }
+    hidden_dim = 10
+    args = args_from_dict(tmpdir, config_dict)
+
+    model = SimpleModel(hidden_dim, rank=args.local_rank)
+
+    @distributed_test(world_size=[2])
+    def _helper(args, model, hidden_dim):
+        model, _, _, _ = deepspeed.initialize(args=args,
+                                              model=model,
+                                              model_parameters=model.parameters())
+        if valid_mode == "FAIL":
+            with pytest.raises(AssertionError):
+                model.save_checkpoint(save_dir=tmpdir,
+                                      tag=f"tag-{torch.distributed.get_rank()}")
+        else:
+            model.save_checkpoint(save_dir=tmpdir,
+                                  tag=f"tag-{torch.distributed.get_rank()}")
+
+    _helper(args=args, model=model, hidden_dim=hidden_dim)
+
+
+def test_checkpoint_unknown_tag_validation(tmpdir):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "checkpoint": {
+            "tag_validation": "foo"
+        }
+    }
+    hidden_dim = 10
+    args = args_from_dict(tmpdir, config_dict)
+
+    model = SimpleModel(hidden_dim, rank=args.local_rank)
+
+    @distributed_test(world_size=[1])
+    def _helper(args, model, hidden_dim):
+        with pytest.raises(deepspeed.DeepSpeedConfigError):
+            model, _, _, _ = deepspeed.initialize(args=args,
+                                                  model=model,
+                                                  model_parameters=model.parameters())
+
+    _helper(args=args, model=model, hidden_dim=hidden_dim)
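Both tests follow the existing pattern in this file: the distributed_test decorator spawns the indicated world size, and the rank-unique tag f"tag-{rank}" is exactly the kind of name the validator is meant to catch. To run just the new cases, a standard pytest invocation should work, assuming the test-file path shown in this diff:

pytest tests/unit/test_checkpointing.py -k "unique_tag or unknown_tag_validation"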