Unverified Commit 20557f70 authored by Shaden Smith, committed by GitHub

Fix ThroughputTimer with hybrid parallelism. (#171)

parent a76572dc
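
Under hybrid (model + data) parallelism, every rank in a model-parallel group works on the same samples, so the number of distinct samples processed per optimizer step scales with the data-parallel world size, not the global world size. A minimal sketch of that accounting, under the assumption of a uniform grid of ranks; samples_per_step and its arguments are illustrative names, not DeepSpeed API:

import math

def samples_per_step(micro_batch_size, grad_accum_steps, world_size, mp_size):
    # Ranks within one model-parallel group share the same micro-batch,
    # so only the data-parallel replicas contribute distinct samples.
    dp_world_size = world_size // mp_size
    return micro_batch_size * grad_accum_steps * dp_world_size

# e.g. 16 GPUs with 4-way model parallelism -> 4 data-parallel replicas
assert samples_per_step(8, 2, 16, 4) == 64

Counting with the global world size instead would overstate throughput by a factor of the model-parallel degree, which is what the diff below corrects.
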
@@ -142,18 +142,18 @@ class DeepSpeedLight(Module):
         self._init_distributed(dist_init_required)
+        # Configure distributed model
+        self._configure_distributed_model(model)
         # Throughput timer
         self.tput_timer = ThroughputTimer(
             batch_size=self.train_micro_batch_size_per_gpu(),
-            num_workers=self.world_size,
+            num_workers=self.dp_world_size,
             monitor_memory=False)
         self.training_dataloader = self.deepspeed_io(
             training_data) if training_data else None
-        # Configure distributed model
-        self._configure_distributed_model(model)
         # Configure optimizer and scheduler
         self.optimizer = None
         self.lr_scheduler = None
@@ -324,17 +324,19 @@ class DeepSpeedLight(Module):
     def _configure_checkpointing(self, dist_init_required):
-        dp_rank = torch.distributed.get_rank(
-        ) if self.mpu is None else self.mpu.get_data_parallel_rank()
+        dp_rank = self.global_rank
+        if self.mpu:
+            dp_rank = self.mpu.get_data_parallel_rank()
         #only the first data parallel process needs to store the model checkpoint
-        self.save_non_zero_checkpoint = True if dp_rank == 0 else False
+        self.save_non_zero_checkpoint = (dp_rank == 0)
         if self.zero_optimization():
             pp_rank = torch.distributed.get_rank(group=self.optimizer.dp_process_group)
-            #only the first parameter parallel process needs to store the optimizer state checkpoints for zero
-            self.save_zero_checkpoint = True if pp_rank == dp_rank else False
+            # Only the first parameter parallel process needs to store the
+            # optimizer state checkpoints for zero
+            self.save_zero_checkpoint = (pp_rank == dp_rank)
     def _scheduler_from_config(self, optimizer):
         scheduler_name = self.scheduler_name()
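
The rule above rests on the fact that model weights are replicated across data-parallel replicas, so a single writer per replica group is enough, while ZeRO's optimizer state is partitioned and each partition owner persists its own shard. A minimal sketch of the replicated-weights guard only; save_model_checkpoint, its arguments, and the file naming are hypothetical, not DeepSpeed's checkpoint API:

import os
import torch

def save_model_checkpoint(model, save_dir, tag, dp_rank):
    # All data-parallel replicas hold identical weights, so only the
    # first replica writes; the others skip the (redundant) write.
    if dp_rank != 0:
        return
    os.makedirs(save_dir, exist_ok=True)
    path = os.path.join(save_dir, f"{tag}_model_states.pt")
    torch.save(model.state_dict(), path)
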
@@ -621,11 +623,12 @@ class DeepSpeedLight(Module):
             allreduce_gradients: If this is False, then gradient averaging will be skipped. Default is True.
         """
-        if self.is_gradient_accumulation_boundary() and self.tensorboard_enabled(
-        ) and torch.distributed.get_rank(
-        ) == 0:  # deepspeed tensorboard support for loss
+        # Log training Loss
+        if self.tensorboard_enabled():
+            if self.is_gradient_accumulation_boundary():
+                if self.global_rank == 0:
                     self.sample_count += (self.train_micro_batch_size_per_gpu() *
-                                          torch.distributed.get_world_size() *
+                                          self.dp_world_size *
                                           self.gradient_accumulation_steps())
                     self.summary_events = [
                         (f'Train/Samples/train_loss',
@@ -712,8 +715,10 @@ class DeepSpeedLight(Module):
         self.tput_timer.stop(report_progress)
-        if self.is_gradient_accumulation_boundary() and self.tensorboard_enabled(
-        ) and torch.distributed.get_rank() == 0:  # deepspeed tensorboard support for lr
+        # Log learning rate
+        if self.tensorboard_enabled():
+            if self.is_gradient_accumulation_boundary():
+                if self.global_rank == 0:
                     self.summary_events = [(f'Train/Samples/lr',
                                             self.get_lr()[0],
                                             self.sample_count)]
@@ -731,9 +736,10 @@ class DeepSpeedLight(Module):
                     'backward_allreduce_microstep',
                     'step_microstep'
                 ])
-            if self.tensorboard_enabled() and torch.distributed.get_rank(
-            ) == 0:  # this is done before the log because log resets timers
+            # Log timing
+            if self.tensorboard_enabled():
+                if self.is_gradient_accumulation_boundary():
+                    if self.global_rank == 0:
                         self.summary_events = [(f'Train/Samples/elapsed_time_ms_forward', self.timers('forward').elapsed(reset=False) * 1000.0, self.sample_count), \
                             (f'Train/Samples/elapsed_time_ms_backward', self.timers('backward').elapsed(reset=False) * 1000.0, self.sample_count), \
                             (f'Train/Samples/elapsed_time_ms_backward_inner', self.timers('backward_inner').elapsed(reset=False) * 1000.0, self.sample_count), \
@@ -870,7 +876,7 @@ class DeepSpeedLight(Module):
         return csr
     def csr_all_gather(self, value):
-        my_size = torch.LongTensor([value.size()[0]]).cuda()
+        my_size = torch.LongTensor([value.size()[0]]).to(self.device)
         all_sizes = self.all_gather_scalar(my_size)
         max_size = torch.cat(all_sizes).max()
         fill_size = (max_size - my_size)
@@ -879,22 +885,22 @@ class DeepSpeedLight(Module):
         if value.dim() == 1:
             if fill_size > 0:
                 value = torch.cat([value, value.new_zeros(fill_size)])
-            tensor_list = [
-                value.new_zeros(max_size) for _ in range(dist.get_world_size())
-            ]
+            tensor_list = [value.new_zeros(max_size) for _ in range(self.dp_world_size)]
         else:
             if fill_size > 0:
                 value = torch.cat([value, value.new_zeros(fill_size, value.size()[1])])
             tensor_list = [
                 value.new_zeros(max_size,
-                                value.size()[1]) for _ in range(dist.get_world_size())
+                                value.size()[1]) for _ in range(self.dp_world_size)
             ]
         dist.all_gather(tensor_list, value, group=self.data_parallel_group)
         tensors = []
         for dev_idx, t in enumerate(tensor_list):
             size = all_sizes[dev_idx][0]
-            tensors.append(t.index_select(0, torch.LongTensor(range(size)).cuda()))
+            tensors.append(
+                t.index_select(0,
+                               torch.LongTensor(range(size)).to(self.device)))
         return tensors
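
csr_all_gather works around all_gather's requirement that every participating tensor have the same shape: each rank pads its rows up to the largest count in the group, gathers the padded tensors, then trims each result back to the sender's true length using the gathered sizes. A single-process sketch of the same pad/gather/trim idea; fake_all_gather is a stand-in for torch.distributed.all_gather, not a real collective:

import torch

def fake_all_gather(values):
    # `values` plays the role of one variable-length tensor per rank.
    sizes = [v.size(0) for v in values]
    max_size = max(sizes)
    # Pad every tensor to the largest length so shapes match.
    padded = [torch.cat([v, v.new_zeros(max_size - v.size(0))]) for v in values]
    # A real all_gather would exchange `padded` across ranks; here every
    # "rank" simply sees the full padded list.
    gathered = [p.clone() for p in padded]
    # Trim each gathered tensor back to the sender's original length.
    return [t[:n] for t, n in zip(gathered, sizes)]

ranks = [torch.tensor([1., 2., 3.]), torch.tensor([4.])]
print(fake_all_gather(ranks))   # [tensor([1., 2., 3.]), tensor([4.])]
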
@@ -1036,8 +1042,8 @@ class DeepSpeedLight(Module):
     def _create_checkpoint_files(self, save_dir, tag):
         #checkpoint files are created sequentially
-        for rank in range(dist.get_world_size()):
-            if rank == dist.get_rank():
+        for rank in range(self.world_size):
+            if rank == self.global_rank:
                 try:
                     if self.save_non_zero_checkpoint:
                         checkpoint_name = self._get_ckpt_name(save_dir, tag)