Unverified Commit 3cc96e17 authored by Olatunji Ruwase, committed by GitHub

Avoid deadlock for unsynchronized non-zero checkpointing (#297)



* Avoid deadlock for unsynchronized non-zero checkpointing

* Fix formatting issues
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
parent 871f7e63
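For context on the fix in the diff below, here is a minimal sketch of the failure mode and the pattern the non-zero checkpoint path moves to. The helper names are hypothetical, not the DeepSpeedLight code: dist.barrier() is a collective call, so if only some ranks reach it, the ranks that do will wait forever.

import os
import torch.distributed as dist

def create_dirs_with_barriers(save_dir, tag, should_save):
    # Old-style pattern: every rank is expected to enter this loop and hit
    # every barrier. With unsynchronized non-zero checkpointing, ranks that
    # have nothing to save may never call this, deadlocking the ranks that do.
    for rank in range(dist.get_world_size()):
        if rank == dist.get_rank() and should_save:
            os.makedirs(os.path.join(save_dir, tag), exist_ok=True)
        dist.barrier()

def create_dirs_without_barriers(save_dir, tag, should_save):
    # Barrier-free pattern adopted for the non-zero checkpoint: no collective
    # calls, so ranks that skip saving cannot hang the ranks that do save.
    if should_save:
        os.makedirs(os.path.join(save_dir, tag), exist_ok=True)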
@@ -227,7 +227,8 @@ class DeepSpeedLight(Module):
if not dist_init_required and dist.is_initialized():
assert dist.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, dist.get_rank())
assert dist.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(world_size, dist.get_world_size())
assert dist.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
world_size, dist.get_world_size())
def tensorboard_enabled(self):
return self._config.tensorboard_enabled
@@ -375,7 +376,7 @@ class DeepSpeedLight(Module):
if self.mpu:
dp_rank = self.mpu.get_data_parallel_rank()
#only the first data parallel process needs to store the model checkpoint
# only the first data parallel process needs to store the model checkpoint
self.save_non_zero_checkpoint = (dp_rank == 0)
if self.zero_optimization():
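A side note on the comment above, as an illustrative sketch only (should_save_non_zero_checkpoint is a hypothetical helper, not engine code): with model parallelism every data-parallel group has its own rank 0, so several processes can need to write a model checkpoint at the same time, which is why the barrier-free file creation later in this diff matters.

def should_save_non_zero_checkpoint(mpu, global_rank):
    # Illustrative only: decide whether this process writes the model
    # (non-zero) checkpoint.
    if mpu is None:
        # pure data parallelism: only global rank 0 writes the model state
        return global_rank == 0
    # model parallelism: rank 0 of each data-parallel group writes its own
    # model-parallel partition, so multiple ranks may return True
    return mpu.get_data_parallel_rank() == 0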
@@ -596,7 +597,8 @@ class DeepSpeedLight(Module):
dp_process_group=self.data_parallel_group,
mpu=self.mpu)
elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS:
assert self.gradient_accumulation_steps() == 1, "ZeRO stage 2 does not support gradient accumulation, if you need gradient accumulation please use stage 1"
assert self.gradient_accumulation_steps(
) == 1, "ZeRO stage 2 does not support gradient accumulation, if you need gradient accumulation please use stage 1"
optimizer = FP16_DeepSpeedZeroOptimizer(
optimizer,
timers=self.timers,
@@ -831,8 +833,8 @@ class DeepSpeedLight(Module):
self.optimizer.step()
#zero grad in basic optimizer could be unreliable and may not exhibit
#the behaviour that we want
# zero grad in basic optimizer could be unreliable and may not exhibit
# the behaviour that we want
if not self.zero_optimization() and not self.fp16_enabled():
self.zero_grad()
else:
@@ -883,11 +885,23 @@ class DeepSpeedLight(Module):
if self.is_gradient_accumulation_boundary():
if self.tensorboard_enabled():
if self.global_rank == 0:
self.summary_events = [(f'Train/Samples/elapsed_time_ms_forward', self.timers('forward').elapsed(reset=False) * 1000.0, self.sample_count), \
(f'Train/Samples/elapsed_time_ms_backward', self.timers('backward').elapsed(reset=False) * 1000.0, self.sample_count), \
(f'Train/Samples/elapsed_time_ms_backward_inner', self.timers('backward_inner').elapsed(reset=False) * 1000.0, self.sample_count), \
(f'Train/Samples/elapsed_time_ms_backward_allreduce', self.timers('backward_allreduce').elapsed(reset=False) * 1000.0, self.sample_count), \
(f'Train/Samples/elapsed_time_ms_step', self.timers('step').elapsed(reset=False) * 1000.0, self.sample_count)
self.summary_events = [
(f'Train/Samples/elapsed_time_ms_forward',
self.timers('forward').elapsed(reset=False) * 1000.0,
self.sample_count),
(f'Train/Samples/elapsed_time_ms_backward',
self.timers('backward').elapsed(reset=False) * 1000.0,
self.sample_count),
(f'Train/Samples/elapsed_time_ms_backward_inner',
self.timers('backward_inner').elapsed(reset=False) * 1000.0,
self.sample_count),
(f'Train/Samples/elapsed_time_ms_backward_allreduce',
self.timers('backward_allreduce').elapsed(reset=False) *
1000.0,
self.sample_count),
(f'Train/Samples/elapsed_time_ms_step',
self.timers('step').elapsed(reset=False) * 1000.0,
self.sample_count)
]
for event in self.summary_events: # write_summary_events
self.summary_writer.add_scalar(event[0], event[1], event[2])
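For reference, a standalone sketch of the (tag, value, step) logging pattern above. It assumes torch.utils.tensorboard; the engine's own summary_writer may be constructed differently, and the values are made up.

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='./runs/example')  # hypothetical log directory
sample_count = 4096                               # made-up step value
events = [('Train/Samples/elapsed_time_ms_forward', 12.5, sample_count),
          ('Train/Samples/elapsed_time_ms_backward', 30.1, sample_count)]
for tag, value, step in events:  # mirrors the write_summary_events loop above
    writer.add_scalar(tag, value, step)
writer.flush()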
@@ -1260,40 +1274,45 @@ class DeepSpeedLight(Module):
client_state: Optional. State dictionary used for saving required training states in the client code.
"""
#This is to make sure the checkpoint names are created without collision
#There seems to be issue creating them in parallel
self._create_checkpoint_files(save_dir, tag)
# This is to make sure the checkpoint names are created without collision
# There seems to be an issue creating them in parallel
if self.save_non_zero_checkpoint:
self._create_checkpoint_file(save_dir, tag, False)
self._save_checkpoint(save_dir, tag, client_state=client_state)
if self.save_zero_checkpoint:
self._create_zero_checkpoint_files(save_dir, tag)
self._save_zero_checkpoint(save_dir, tag)
return True
def _create_checkpoint_files(self, save_dir, tag):
#checkpoint files are created sequentially
for rank in range(self.world_size):
if rank == self.global_rank:
def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint):
name_function = self._get_zero_ckpt_name if zero_checkpoint else self._get_ckpt_name
try:
if self.save_non_zero_checkpoint:
checkpoint_name = self._get_ckpt_name(save_dir, tag)
self._ensure_directory_exists(checkpoint_name)
if self.save_zero_checkpoint:
checkpoint_name = self._get_zero_ckpt_name(save_dir, tag)
checkpoint_name = name_function(save_dir, tag)
self._ensure_directory_exists(checkpoint_name)
except:
logger.error(
f'Failed Saving model checkpoint to {save_dir} with tag {tag}')
logger.error(f'Failed Saving model checkpoint to {save_dir} with tag {tag}')
return False
return True
def _create_zero_checkpoint_files(self, save_dir, tag):
success = True
# zero checkpoint files are created sequentially
for rank in range(self.world_size):
if rank == self.global_rank:
success = self._create_checkpoint_file(save_dir, tag, True)
dist.barrier()
return success
def _save_checkpoint(self, save_dir, tag, client_state={}):
save_path = self._get_ckpt_name(save_dir, tag)
#self._ensure_directory_exists(save_path)
# self._ensure_directory_exists(save_path)
state = {
'module':
@@ -1321,7 +1340,7 @@ class DeepSpeedLight(Module):
def _save_zero_checkpoint(self, save_path, tag):
zero_checkpoint_name = self._get_zero_ckpt_name(save_path, tag)
#self._ensure_directory_exists(zero_checkpoint_name)
# self._ensure_directory_exists(zero_checkpoint_name)
zero_sd = {'optimizer_state_dict': self.optimizer.state_dict()}
torch.save(zero_sd, zero_checkpoint_name)
logger.info('zero checkpoint saved {}'.format(zero_checkpoint_name))
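A hedged usage sketch from the training loop's side (maybe_save, the engine variable, the directory, and the interval are illustrative; only save_dir, tag, and client_state come from the save_checkpoint docstring above):

def maybe_save(engine, step, save_interval=1000):
    # Illustrative call site for the checkpoint API shown in this diff.
    if step % save_interval == 0:
        # After this change the non-zero checkpoint path issues no barriers,
        # so ranks that skip saving cannot hang the ranks that do save.
        engine.save_checkpoint('./checkpoints',
                               f'global_step{step}',
                               client_state={'step': step})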