Unverified Commit 845921b3 authored by Jeff Rasley, committed by GitHub

Add 'latest' checkpoint save/load support (#569)

parent 7a75f8b3
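
In short, this change makes the checkpoint tag optional on both save and load: save_checkpoint falls back to a global-step tag and records it in a small 'latest' file inside save_dir, and load_checkpoint called without a tag reads that file to decide which checkpoint to restore. A minimal caller-side sketch (the engine setup and the "checkpoints" path below are illustrative assumptions, not part of this commit):

    import deepspeed

    # Assumes `args`, `model`, and a DeepSpeed config are already set up as usual.
    engine, _, _, _ = deepspeed.initialize(args=args,
                                            model=model,
                                            model_parameters=model.parameters())

    # The tag now defaults to f"global_step{engine.global_steps}" and the save also
    # writes/updates <save_dir>/latest (pass save_latest=False to skip that file).
    engine.save_checkpoint("checkpoints")

    # Omitting the tag on load resolves it from 'latest'; an explicit tag still
    # works exactly as before.
    load_path, client_state = engine.load_checkpoint("checkpoints")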
@@ -1300,15 +1300,15 @@ class DeepSpeedEngine(Module):

     def load_checkpoint(self,
                         load_dir,
-                        tag,
+                        tag=None,
                         load_module_strict=True,
                         load_optimizer_states=True,
                         load_lr_scheduler_states=True):
-        r"""Load training checkpoint
+        """Load training checkpoint

         Arguments:
             load_dir: Required. Directory to load the checkpoint from
-            tag: Required. Checkpoint tag used as a unique identifier for the checkpoint. Ex. Global Step.
+            tag: Checkpoint tag used as a unique identifier for checkpoint, if not provided will attempt to load tag in 'latest' file
             load_module_strict: Optional. Boolean to strictly enforce that the keys in state_dict of module and checkpoint match.
             load_optimizer_states: Optional. Boolean to load the training optimizer states from Checkpoint. Ex. ADAM's momentum and variance
             load_lr_scheduler_states: Optional. Boolean to add the learning rate scheduler states from Checkpoint.
@@ -1317,6 +1317,13 @@ class DeepSpeedEngine(Module):
             client_state: State dictionary used for loading required training states in the client code.
         """
+        if tag is None:
+            latest_path = os.path.join(load_dir, 'latest')
+            assert os.path.isfile(latest_path), f"Unable to find latest file at {latest_path}, if trying to load latest " \
+                "checkpoint please ensure this file exists or pass an explicit checkpoint tag when loading a checkpoint."
+            with open(latest_path, 'r') as fd:
+                tag = fd.read().strip()
+
         load_path, client_states = self._load_checkpoint(load_dir,
                                                          tag,
                                                          load_module_strict=load_module_strict,
@@ -1454,18 +1461,25 @@ class DeepSpeedEngine(Module):
         )
         return zero_optimizer_sd

-    def save_checkpoint(self, save_dir, tag, client_state={}):
+    def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True):
         r"""Save training checkpoint

         Arguments:
             save_dir: Required. Directory for saving the checkpoint
-            tag: Required. Checkpoint tag used as a unique identifier for the checkpoint. Ex. Global Step.
+            tag: Optional. Checkpoint tag used as a unique identifier for the checkpoint, global step is used if not provided.
             client_state: Optional. State dictionary used for saving required training states in the client code.
+            save_latest: Optional. Save a file 'latest' pointing to the latest saved checkpoint.
         """
         # This is to make sure the checkpoint names are created without collision
         # There seems to be issue creating them in parallel
+
+        # Ensure save_dir directory exists
+        os.makedirs(save_dir, exist_ok=True)
+
+        if tag is None:
+            tag = f"global_step{self.global_steps}"
+
         if self.save_non_zero_checkpoint:
             self._create_checkpoint_file(save_dir, tag, False)
             self._save_checkpoint(save_dir, tag, client_state=client_state)
@@ -1474,6 +1488,11 @@ class DeepSpeedEngine(Module):
             self._create_zero_checkpoint_files(save_dir, tag)
             self._save_zero_checkpoint(save_dir, tag)

+        # Save latest checkpoint tag
+        if save_latest:
+            with open(os.path.join(save_dir, 'latest'), 'w') as fd:
+                fd.write(tag)
+
         return True

     def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint):
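
For reference, after a tag-less save the checkpoint directory holds the per-tag checkpoint folder plus the one-line 'latest' file, and the tag-less load path is just a read of that file. A standalone sketch of the same resolution logic (the directory name and step count below are made-up examples):

    import os

    save_dir = "checkpoints"          # assumed path, not taken from the commit
    # Layout after engine.save_checkpoint(save_dir) with the default tag:
    #   checkpoints/latest            <- text file containing e.g. "global_step100"
    #   checkpoints/global_step100/   <- the actual model/optimizer checkpoint files
    with open(os.path.join(save_dir, 'latest'), 'r') as fd:
        tag = fd.read().strip()       # same resolution load_checkpoint(tag=None) performs
    print(os.path.join(save_dir, tag))  # -> checkpoints/global_step100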
@@ -128,7 +128,8 @@ def checkpoint_correctness_verification(args,
                                         fp16=True,
                                         train_batch=False,
                                         base_optimizers=[None,
-                                                         None]):
+                                                         None],
+                                        empty_tag=False):
     dtype = torch.half if fp16 else torch.float32
     ds_model = create_deepspeed_model(args=args,
                                       model=models[0],
@@ -153,16 +154,16 @@ def checkpoint_correctness_verification(args,
     trained_model = ds_model

     save_folder = os.path.join(tmpdir, 'saved_checkpoint')
-    save_tag = '1'
+    save_tag = None if empty_tag else '1'

-    trained_model.save_checkpoint(save_folder, save_tag)
+    trained_model.save_checkpoint(save_folder, tag=save_tag)

     loaded_model = create_deepspeed_model(args=args,
                                           model=models[1],
                                           base_optimizer=base_optimizers[1])

     loaded_model.load_checkpoint(save_folder,
-                                 save_tag,
+                                 tag=save_tag,
                                  load_optimizer_states=load_optimizer_states,
                                  load_lr_scheduler_states=load_lr_scheduler_states)
@@ -704,3 +705,59 @@ def test_checkpoint_zero_hybrid_optimizer_state(tmpdir, zero_stage):
                                         models=models,
                                         optimizers=optimizers,
                                         hidden_dim=hidden_dim)
+
+
+def test_checkpoint_latest(tmpdir):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        }
+    }
+    hidden_dim = 10
+    args = args_from_dict(tmpdir, config_dict)
+    models = [SimpleModel(hidden_dim=hidden_dim) for _ in range(2)]
+
+    @distributed_test(world_size=[1])
+    def _helper(args, models):
+        checkpoint_correctness_verification(args,
+                                            models=models,
+                                            hidden_dim=hidden_dim,
+                                            tmpdir=tmpdir,
+                                            load_optimizer_states=True,
+                                            load_lr_scheduler_states=False,
+                                            fp16=False,
+                                            empty_tag=True)
+
+    _helper(args, models)
+
+
+def test_checkpoint_missing_latest(tmpdir):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        }
+    }
+    hidden_dim = 10
+    args = args_from_dict(tmpdir, config_dict)
+    model = SimpleModel(hidden_dim, rank=args.local_rank)
+
+    @distributed_test(world_size=[1])
+    def _helper(args, model, hidden_dim):
+        model, _, _, _ = deepspeed.initialize(args=args,
+                                              model=model,
+                                              model_parameters=model.parameters())
+        with pytest.raises(AssertionError):
+            model.load_checkpoint(tmpdir)
+
+    _helper(args=args, model=model, hidden_dim=hidden_dim)