Commit c671de3e authored by Deepak Narayanan's avatar Deepak Narayanan

Move division of loss tensor by number of microbatches to training.py

parent 69a546be
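In words: with pipelining and gradient accumulation, the backward pass of each microbatch adds into the same gradient buffers, so scaling each microbatch loss by 1/num_microbatches makes the accumulated gradient equal the gradient of the minibatch-average loss. A minimal PyTorch sketch of that idea (the model and data below are placeholders, not Megatron-LM code):

import torch

# Toy model and a minibatch split into microbatches; purely illustrative.
model = torch.nn.Linear(4, 1)
microbatches = [torch.randn(8, 4) for _ in range(4)]
num_microbatches = len(microbatches)

model.zero_grad()
for batch in microbatches:
    loss = model(batch).pow(2).mean()
    # Scale by 1/num_microbatches so the gradients accumulated over all
    # microbatches equal the gradient of the minibatch-average loss.
    (loss / num_microbatches).backward()

# model.weight.grad now holds the average-loss gradient; an optimizer
# step here would apply one update per minibatch.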
@@ -294,6 +294,8 @@ def backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_g
 def forward_step_with_communication(forward_step_func, data_iterator, model,
                                     input_tensors, output_tensors,
                                     losses_reduced, timers):
+    args = get_args()
     if not mpu.is_pipeline_first_stage():
         timers('forward-recv').start()
         input_tensor, _ = communicate(
@@ -312,7 +314,7 @@ def forward_step_with_communication(forward_step_func, data_iterator, model,
     if mpu.is_pipeline_last_stage():
         loss, loss_reduced = output_tensor
-        output_tensor = loss
+        output_tensor = loss / args.num_microbatches_in_minibatch
         losses_reduced.append(loss_reduced)
     else:
         timers('forward-send').start()
@@ -328,7 +330,6 @@ def forward_step_with_communication(forward_step_func, data_iterator, model,
 def backward_step_with_communication(optimizer, model, input_tensors, output_tensors, timers):
-    """Backward step."""
     input_tensor = input_tensors.pop(0)
     output_tensor = output_tensors.pop(0)
@@ -364,6 +365,8 @@ def forward_and_backward_steps_with_communication(forward_step_func, data_iterat
                                                   input_tensor, last_microbatch,
                                                   input_tensors, output_tensors,
                                                   losses_reduced, timers):
+    args = get_args()
     # Forward model for one step.
     timers('forward-compute').start()
     output_tensor = forward_step_func(data_iterator, model, input_tensor)
@@ -371,7 +374,7 @@ def forward_and_backward_steps_with_communication(forward_step_func, data_iterat
     if mpu.is_pipeline_last_stage():
         loss, loss_reduced = output_tensor
-        output_tensor = loss
+        output_tensor = loss / args.num_microbatches_in_minibatch
         output_tensor_grad = None
         losses_reduced.append(loss_reduced)
     else:
@@ -418,7 +421,7 @@ def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
     for i in range(args.num_microbatches_in_minibatch):
         timers('forward-compute').start()
         loss, loss_reduced = forward_step_func(data_iterator, model, input_tensor=None)
-        output_tensor = loss
+        output_tensor = loss / args.num_microbatches_in_minibatch
         losses_reduced.append(loss_reduced)
         timers('forward-compute').stop()
@@ -571,7 +574,7 @@ def train_step(forward_step_func, data_iterator,
         loss_reduced = {}
         for key in losses_reduced[0]:
             losses_reduced_for_key = [x[key] for x in losses_reduced]
-            loss_reduced[key] = sum(losses_reduced_for_key)
+            loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)
         return loss_reduced, skipped_iter
     return {}, skipped_iter
@@ -770,7 +773,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
                 _, loss_dict = output_tensor
                 # Reduce across processes.
                 for key in loss_dict:
-                    total_loss_dict[key] = total_loss_dict.get(key, 0.) + \
+                    total_loss_dict[key] = total_loss_dict.get(key, torch.cuda.FloatTensor([0.0])) + \
                         loss_dict[key]
             else:
                 communicate(
...
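Two consequences inside this file: train_step now reports the mean of the per-microbatch reduced losses instead of their sum, and evaluate seeds its running totals with torch.cuda.FloatTensor([0.0]), presumably to keep the accumulation a device tensor. A standalone sketch of the logging-side averaging, with made-up loss values:

import torch

# Illustrative per-microbatch reduced losses, as train_step might collect them.
losses_reduced = [{'lm loss': torch.tensor(2.0)},
                  {'lm loss': torch.tensor(4.0)}]

loss_reduced = {}
for key in losses_reduced[0]:
    losses_reduced_for_key = [x[key] for x in losses_reduced]
    # Average over microbatches for logging, matching the division of the
    # training loss by num_microbatches_in_minibatch above.
    loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)

print(loss_reduced)  # {'lm loss': tensor(3.)}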
@@ -118,8 +118,7 @@ def forward_step(data_iterator, model, input_tensor):
         lm_loss_ = lm_loss_.float()
         loss_mask = loss_mask.float()
         lm_loss = torch.sum(
-            lm_loss_.view(-1) * loss_mask.reshape(-1)) / (
-            loss_mask.sum() * args.num_microbatches_in_minibatch)
+            lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
         loss = lm_loss + sop_loss
...
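The pretraining scripts' forward_step functions (this hunk looks like the BERT one, given sop_loss) now return the plain masked average over the microbatch's tokens, leaving the microbatch scaling to training.py. A self-contained sketch of that masked average with toy tensors:

import torch

# Per-token losses and a mask selecting which tokens count toward the loss.
lm_loss_ = torch.tensor([[0.5, 1.5], [2.0, 4.0]])   # per-token losses
loss_mask = torch.tensor([[1.0, 0.0], [1.0, 1.0]])  # 1 = count this token

# Mean of the masked token losses for this microbatch only.
lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
print(lm_loss)  # tensor(2.1667): (0.5 + 2.0 + 4.0) / 3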
@@ -110,8 +110,7 @@ def forward_step(data_iterator, model, input_tensor):
     if mpu.is_pipeline_last_stage():
         losses = output_tensor.float()
         loss_mask = loss_mask.view(-1).float()
-        loss = torch.sum(losses.view(-1) * loss_mask) / (
-            loss_mask.sum() * args.num_microbatches_in_minibatch)
+        loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
         # Reduce loss for logging.
         averaged_loss = average_losses_across_data_parallel_group([loss])
...
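The refactor should be numerically neutral for the training loss: folding num_microbatches_in_minibatch into the denominator here, as the old code did, equals computing the masked average first and dividing by the microbatch count later in training.py. A quick check with made-up values:

import torch

losses = torch.tensor([1.0, 3.0, 5.0])
loss_mask = torch.tensor([1.0, 1.0, 0.0])
num_microbatches = 4

# Old: divide once by (token count * microbatch count) in the forward step.
old = torch.sum(losses * loss_mask) / (loss_mask.sum() * num_microbatches)
# New: masked average in the forward step, microbatch scaling in training.py.
new = (torch.sum(losses * loss_mask) / loss_mask.sum()) / num_microbatches
assert torch.allclose(old, new)  # both 0.5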