Commit c671de3e authored by Deepak Narayanan

Move division of loss tensor by number of microbatches to training.py

parent 69a546be
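The commit moves the 1/num_microbatches scaling out of each model's forward_step and into the schedules in training.py, so the model now returns a plain per-microbatch loss. The reason for the scaling itself is unchanged: when gradients are accumulated over the microbatches of a minibatch, dividing each microbatch loss by the number of microbatches makes the accumulated gradient equal to the gradient of the minibatch-mean loss. A minimal PyTorch sketch of that equivalence (the variables here are illustrative, not Megatron's API):

```python
import torch

# Scaling each microbatch loss by 1/num_microbatches during gradient
# accumulation reproduces the gradient of the minibatch-mean loss.
torch.manual_seed(0)
w = torch.randn(4, requires_grad=True)
microbatches = [torch.randn(4) for _ in range(8)]
num_microbatches = len(microbatches)

# (a) accumulate gradients of scaled per-microbatch losses
for x in microbatches:
    ((w * x).sum() / num_microbatches).backward()
accumulated_grad = w.grad.clone()

# (b) gradient of the single minibatch-mean loss
w.grad = None
torch.stack([(w * x).sum() for x in microbatches]).mean().backward()
assert torch.allclose(accumulated_grad, w.grad)
```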
@@ -294,6 +294,8 @@ def backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_g
def forward_step_with_communication(forward_step_func, data_iterator, model,
input_tensors, output_tensors,
losses_reduced, timers):
+ args = get_args()
if not mpu.is_pipeline_first_stage():
timers('forward-recv').start()
input_tensor, _ = communicate(
@@ -312,7 +314,7 @@ def forward_step_with_communication(forward_step_func, data_iterator, model,
if mpu.is_pipeline_last_stage():
loss, loss_reduced = output_tensor
- output_tensor = loss
+ output_tensor = loss / args.num_microbatches_in_minibatch
losses_reduced.append(loss_reduced)
else:
timers('forward-send').start()
@@ -328,7 +330,6 @@ def forward_step_with_communication(forward_step_func, data_iterator, model,
def backward_step_with_communication(optimizer, model, input_tensors, output_tensors, timers):
"""Backward step."""
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
@@ -364,6 +365,8 @@ def forward_and_backward_steps_with_communication(forward_step_func, data_iterat
input_tensor, last_microbatch,
input_tensors, output_tensors,
losses_reduced, timers):
+ args = get_args()
# Forward model for one step.
timers('forward-compute').start()
output_tensor = forward_step_func(data_iterator, model, input_tensor)
@@ -371,7 +374,7 @@ def forward_and_backward_steps_with_communication(forward_step_func, data_iterat
if mpu.is_pipeline_last_stage():
loss, loss_reduced = output_tensor
- output_tensor = loss
+ output_tensor = loss / args.num_microbatches_in_minibatch
output_tensor_grad = None
losses_reduced.append(loss_reduced)
else:
@@ -418,7 +421,7 @@ def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
for i in range(args.num_microbatches_in_minibatch):
timers('forward-compute').start()
loss, loss_reduced = forward_step_func(data_iterator, model, input_tensor=None)
- output_tensor = loss
+ output_tensor = loss / args.num_microbatches_in_minibatch
losses_reduced.append(loss_reduced)
timers('forward-compute').stop()
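In the no-pipelining path above, the scaling is now applied once per microbatch inside the driver loop, so forward_step_func can return the unscaled loss. A self-contained sketch of such an accumulation loop (run_minibatch and its arguments are illustrative names, not Megatron code):

```python
import torch

def run_minibatch(model, optimizer, microbatches):
    """Illustrative driver loop: the driver, not the model, divides each
    microbatch loss by the number of microbatches before backward()."""
    optimizer.zero_grad()
    num_microbatches = len(microbatches)
    losses_for_logging = []
    for inputs, targets in microbatches:
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        losses_for_logging.append(loss.detach())
        (loss / num_microbatches).backward()  # scale here, once
    optimizer.step()
    # Report the minibatch-average loss, matching the sum/len
    # reduction in train_step below.
    return torch.stack(losses_for_logging).mean()

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
microbatches = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(4)]
print(run_minibatch(model, optimizer, microbatches).item())
```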
@@ -571,7 +574,7 @@ def train_step(forward_step_func, data_iterator,
loss_reduced = {}
for key in losses_reduced[0]:
losses_reduced_for_key = [x[key] for x in losses_reduced]
- loss_reduced[key] = sum(losses_reduced_for_key)
+ loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)
return loss_reduced, skipped_iter
return {}, skipped_iter
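Since the entries appended to losses_reduced are no longer pre-divided by the number of microbatches, the logging reduction in train_step changes from a sum to a mean over microbatches, and the reported value is the same minibatch-average loss as before. For example (reduce_losses_for_logging is a hypothetical helper, not Megatron code):

```python
def reduce_losses_for_logging(losses_reduced):
    """Average each logged key over the microbatches of a minibatch."""
    loss_reduced = {}
    for key in losses_reduced[0]:
        values = [d[key] for d in losses_reduced]
        loss_reduced[key] = sum(values) / len(values)
    return loss_reduced

# Two microbatches with unscaled losses 2.0 and 4.0 report 3.0 -- the same
# value the old code obtained by summing the pre-scaled losses 1.0 + 2.0.
print(reduce_losses_for_logging([{'lm loss': 2.0}, {'lm loss': 4.0}]))
```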
@@ -770,7 +773,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
_, loss_dict = output_tensor
# Reduce across processes.
for key in loss_dict:
- total_loss_dict[key] = total_loss_dict.get(key, 0.) + \
+ total_loss_dict[key] = total_loss_dict.get(key, torch.cuda.FloatTensor([0.0])) + \
loss_dict[key]
else:
communicate(
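In evaluate, the per-key running total is now seeded with torch.cuda.FloatTensor([0.0]) instead of the Python float 0., so the accumulator is a CUDA tensor from the first iteration regardless of what is added to it. A CPU-only sketch of the same accumulation pattern (accumulate is an illustrative name; torch.zeros stands in for the CUDA tensor so the example runs without a GPU):

```python
import torch

def accumulate(total_loss_dict, loss_dict, device="cpu"):
    # Seed each key with a tensor so the running total stays a tensor.
    for key in loss_dict:
        total_loss_dict[key] = total_loss_dict.get(
            key, torch.zeros(1, device=device)) + loss_dict[key]
    return total_loss_dict

totals = {}
for loss_dict in [{'lm loss': torch.tensor([2.0])},
                  {'lm loss': torch.tensor([4.0])}]:
    accumulate(totals, loss_dict)
print(totals['lm loss'])  # tensor([6.])
```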
@@ -118,8 +118,7 @@ def forward_step(data_iterator, model, input_tensor):
lm_loss_ = lm_loss_.float()
loss_mask = loss_mask.float()
lm_loss = torch.sum(
- lm_loss_.view(-1) * loss_mask.reshape(-1)) / (
- loss_mask.sum() * args.num_microbatches_in_minibatch)
+ lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
loss = lm_loss + sop_loss
@@ -110,8 +110,7 @@ def forward_step(data_iterator, model, input_tensor):
if mpu.is_pipeline_last_stage():
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
- loss = torch.sum(losses.view(-1) * loss_mask) / (
- loss_mask.sum() * args.num_microbatches_in_minibatch)
+ loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
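On the model side (the BERT and GPT forward_step hunks above), the loss is now just the masked mean over tokens; the 1/args.num_microbatches_in_minibatch factor is dropped because training.py applies it. A small sketch of that masked mean (masked_mean_loss is an illustrative name):

```python
import torch

def masked_mean_loss(per_token_losses, loss_mask):
    """Masked mean over tokens; no microbatch scaling here any more."""
    per_token_losses = per_token_losses.float().view(-1)
    loss_mask = loss_mask.float().view(-1)
    return torch.sum(per_token_losses * loss_mask) / loss_mask.sum()

losses = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
mask = torch.tensor([[1, 1], [1, 0]])
print(masked_mean_loss(losses, mask))  # tensor(2.)
```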