Unverified commit 3cc96e17, authored by Olatunji Ruwase and committed by GitHub
Browse files

Avoid deadlock for unsynchronized non-zero checkpointing (#297)



* Avoid deadlock for unsynchronized non-zero checkpointing

* Fix formatting issues
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
parent 871f7e63
...@@ -227,7 +227,8 @@ class DeepSpeedLight(Module): ...@@ -227,7 +227,8 @@ class DeepSpeedLight(Module):
if not dist_init_required and dist.is_initialized(): if not dist_init_required and dist.is_initialized():
assert dist.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, dist.get_rank()) assert dist.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, dist.get_rank())
assert dist.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(world_size, dist.get_world_size()) assert dist.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
world_size, dist.get_world_size())
def tensorboard_enabled(self):
    """Whether TensorBoard summary logging is turned on in the config."""
    config = self._config
    return config.tensorboard_enabled
...@@ -375,7 +376,7 @@ class DeepSpeedLight(Module): ...@@ -375,7 +376,7 @@ class DeepSpeedLight(Module):
if self.mpu: if self.mpu:
dp_rank = self.mpu.get_data_parallel_rank() dp_rank = self.mpu.get_data_parallel_rank()
#only the first data parallel process needs to store the model checkpoint # only the first data parallel process needs to store the model checkpoint
self.save_non_zero_checkpoint = (dp_rank == 0) self.save_non_zero_checkpoint = (dp_rank == 0)
if self.zero_optimization(): if self.zero_optimization():
...@@ -596,7 +597,8 @@ class DeepSpeedLight(Module): ...@@ -596,7 +597,8 @@ class DeepSpeedLight(Module):
dp_process_group=self.data_parallel_group, dp_process_group=self.data_parallel_group,
mpu=self.mpu) mpu=self.mpu)
elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS: elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS:
assert self.gradient_accumulation_steps() == 1, "ZeRO stage 2 does not support gradient accumulation, if you need gradient accumulation please use stage 1" assert self.gradient_accumulation_steps(
) == 1, "ZeRO stage 2 does not support gradient accumulation, if you need gradient accumulation please use stage 1"
optimizer = FP16_DeepSpeedZeroOptimizer( optimizer = FP16_DeepSpeedZeroOptimizer(
optimizer, optimizer,
timers=self.timers, timers=self.timers,
...@@ -831,8 +833,8 @@ class DeepSpeedLight(Module): ...@@ -831,8 +833,8 @@ class DeepSpeedLight(Module):
self.optimizer.step() self.optimizer.step()
#zero grad in basic optimizer could be unreliable and may not exhibit # zero grad in basic optimizer could be unreliable and may not exhibit
#the behaviour that we want # the behaviour that we want
if not self.zero_optimization() and not self.fp16_enabled(): if not self.zero_optimization() and not self.fp16_enabled():
self.zero_grad() self.zero_grad()
else: else:
...@@ -883,11 +885,23 @@ class DeepSpeedLight(Module): ...@@ -883,11 +885,23 @@ class DeepSpeedLight(Module):
if self.is_gradient_accumulation_boundary(): if self.is_gradient_accumulation_boundary():
if self.tensorboard_enabled(): if self.tensorboard_enabled():
if self.global_rank == 0: if self.global_rank == 0:
self.summary_events = [(f'Train/Samples/elapsed_time_ms_forward', self.timers('forward').elapsed(reset=False) * 1000.0, self.sample_count), \ self.summary_events = [
(f'Train/Samples/elapsed_time_ms_backward', self.timers('backward').elapsed(reset=False) * 1000.0, self.sample_count), \ (f'Train/Samples/elapsed_time_ms_forward',
(f'Train/Samples/elapsed_time_ms_backward_inner', self.timers('backward_inner').elapsed(reset=False) * 1000.0, self.sample_count), \ self.timers('forward').elapsed(reset=False) * 1000.0,
(f'Train/Samples/elapsed_time_ms_backward_allreduce', self.timers('backward_allreduce').elapsed(reset=False) * 1000.0, self.sample_count), \ self.sample_count),
(f'Train/Samples/elapsed_time_ms_step', self.timers('step').elapsed(reset=False) * 1000.0, self.sample_count) (f'Train/Samples/elapsed_time_ms_backward',
self.timers('backward').elapsed(reset=False) * 1000.0,
self.sample_count),
(f'Train/Samples/elapsed_time_ms_backward_inner',
self.timers('backward_inner').elapsed(reset=False) * 1000.0,
self.sample_count),
(f'Train/Samples/elapsed_time_ms_backward_allreduce',
self.timers('backward_allreduce').elapsed(reset=False) *
1000.0,
self.sample_count),
(f'Train/Samples/elapsed_time_ms_step',
self.timers('step').elapsed(reset=False) * 1000.0,
self.sample_count)
] ]
for event in self.summary_events: # write_summary_events for event in self.summary_events: # write_summary_events
self.summary_writer.add_scalar(event[0], event[1], event[2]) self.summary_writer.add_scalar(event[0], event[1], event[2])
...@@ -1260,40 +1274,45 @@ class DeepSpeedLight(Module): ...@@ -1260,40 +1274,45 @@ class DeepSpeedLight(Module):
client_state: Optional. State dictionary used for saving required training states in the client code. client_state: Optional. State dictionary used for saving required training states in the client code.
""" """
#This is to make sure the checkpoint names are created without collision # This is to make sure the checkpoint names are created without collision
#There seems to be issue creating them in parallel # There seems to be issue creating them in parallel
self._create_checkpoint_files(save_dir, tag)
if self.save_non_zero_checkpoint: if self.save_non_zero_checkpoint:
self._create_checkpoint_file(save_dir, tag, False)
self._save_checkpoint(save_dir, tag, client_state=client_state) self._save_checkpoint(save_dir, tag, client_state=client_state)
if self.save_zero_checkpoint: if self.save_zero_checkpoint:
self._create_zero_checkpoint_files(save_dir, tag)
self._save_zero_checkpoint(save_dir, tag) self._save_zero_checkpoint(save_dir, tag)
return True return True
def _create_checkpoint_files(self, save_dir, tag): def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint):
#checkpoint files are created sequentially name_function = self._get_zero_ckpt_name if zero_checkpoint else self._get_ckpt_name
for rank in range(self.world_size):
if rank == self.global_rank:
try: try:
if self.save_non_zero_checkpoint: checkpoint_name = name_function(save_dir, tag)
checkpoint_name = self._get_ckpt_name(save_dir, tag)
self._ensure_directory_exists(checkpoint_name)
if self.save_zero_checkpoint:
checkpoint_name = self._get_zero_ckpt_name(save_dir, tag)
self._ensure_directory_exists(checkpoint_name) self._ensure_directory_exists(checkpoint_name)
except: except:
logger.error( logger.error(f'Failed Saving model checkpoint to {save_dir} with tag {tag}')
f'Failed Saving model checkpoint to {save_dir} with tag {tag}')
return False return False
return True
def _create_zero_checkpoint_files(self, save_dir, tag):
    """Create this rank's ZeRO checkpoint directory, one rank at a time.

    Ranks take turns (serialized by a barrier on every iteration) so that
    directory creation never races across processes. Every rank executes
    the same number of barriers, which avoids the deadlock that arises
    when only a subset of ranks participates.

    Returns:
        True if this rank's directory was created successfully.
    """
    success = True
    for turn in range(self.world_size):
        # Only the rank whose turn it is touches the filesystem.
        if turn == self.global_rank:
            success = self._create_checkpoint_file(save_dir, tag, True)
        dist.barrier()
    return success
def _save_checkpoint(self, save_dir, tag, client_state={}): def _save_checkpoint(self, save_dir, tag, client_state={}):
save_path = self._get_ckpt_name(save_dir, tag) save_path = self._get_ckpt_name(save_dir, tag)
#self._ensure_directory_exists(save_path) # self._ensure_directory_exists(save_path)
state = { state = {
'module': 'module':
...@@ -1321,7 +1340,7 @@ class DeepSpeedLight(Module): ...@@ -1321,7 +1340,7 @@ class DeepSpeedLight(Module):
def _save_zero_checkpoint(self, save_path, tag):
    """Write this rank's ZeRO optimizer state to its checkpoint file."""
    ckpt_file = self._get_zero_ckpt_name(save_path, tag)
    # Directory creation is handled separately (see _create_zero_checkpoint_files),
    # so no _ensure_directory_exists call is needed here.
    state = {'optimizer_state_dict': self.optimizer.state_dict()}
    torch.save(state, ckpt_file)
    logger.info('zero checkpoint saved {}'.format(ckpt_file))
Markdown is supported
Attach a file by drag &amp; drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.