"docs/vscode:/vscode.git/clone" did not exist on "ca4b86c564a735078aadb0bfb0e3d529735f2c79"
Commit 3d7194c4 authored by Deepak Narayanan

Divide gradient by number of microbatches in minibatch

parent a6756bf8
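With pipelined training, train_step runs a forward and backward pass once per microbatch, and PyTorch accumulates the resulting gradients into each parameter's .grad. Dividing every microbatch's loss by the number of microbatches therefore makes the accumulated gradient equal the gradient of the minibatch-mean loss. A minimal sketch of that identity in plain PyTorch (not Megatron code; the shapes, loss function, and names are illustrative):

import torch

torch.manual_seed(0)
w = torch.randn(4, requires_grad=True)
microbatches = [torch.randn(4) for _ in range(8)]
num_microbatches = len(microbatches)

# Backward once per microbatch; gradients of the scaled losses
# accumulate into w.grad across iterations.
for x in microbatches:
    loss = (w * x).sum() / num_microbatches
    loss.backward()
accumulated = w.grad.clone()

# Reference: a single backward on the minibatch-mean loss.
w.grad = None
mean_loss = torch.stack([(w * x).sum() for x in microbatches]).mean()
mean_loss.backward()

assert torch.allclose(accumulated, w.grad)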
@@ -554,8 +554,7 @@ def train_step(forward_step_func, data_iterator,
         loss_reduced = {}
         for key in losses_reduced[0]:
             losses_reduced_for_key = [x[key] for x in losses_reduced]
-            loss_reduced[key] = sum(losses_reduced_for_key) / \
-                len(losses_reduced_for_key)
+            loss_reduced[key] = sum(losses_reduced_for_key)
         return loss_reduced, skipped_iter
     return {}, skipped_iter
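Because each per-microbatch reduced loss now already carries the 1/num_microbatches factor from forward_step, train_step sums the reduced values instead of averaging them; the sum reconstructs the minibatch-mean loss for logging. A hypothetical numeric check (the values are made up):

unscaled = [2.0, 4.0, 6.0, 8.0]        # per-microbatch mean losses
n = len(unscaled)                      # microbatches per minibatch
scaled = [l / n for l in unscaled]     # what forward_step now reports
assert sum(scaled) == sum(unscaled) / n == 5.0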
@@ -118,7 +118,8 @@ def forward_step(data_iterator, model, input_tensor):
     lm_loss_ = lm_loss_.float()
     loss_mask = loss_mask.float()
     lm_loss = torch.sum(
-        lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
+        lm_loss_.view(-1) * loss_mask.reshape(-1)) / (
+            loss_mask.sum() * args.num_microbatches_in_minibatch)
     loss = lm_loss + sop_loss
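Both pretraining scripts change the per-microbatch loss the same way: average the token losses over the unmasked tokens, then divide by the microbatch count. A self-contained sketch of that normalization (the helper name and explicit num_microbatches argument are illustrative; in the patch the count comes from args.num_microbatches_in_minibatch):

import torch

def microbatch_loss(token_losses, loss_mask, num_microbatches):
    # Mean loss over unmasked tokens, scaled by 1/num_microbatches so
    # that gradients accumulated across microbatches match the gradient
    # of the minibatch-mean loss.
    token_losses = token_losses.float().view(-1)
    loss_mask = loss_mask.float().view(-1)
    return torch.sum(token_losses * loss_mask) / (
        loss_mask.sum() * num_microbatches)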
@@ -110,7 +110,8 @@ def forward_step(data_iterator, model, input_tensor):
     if mpu.is_pipeline_last_stage():
         losses = output_tensor.float()
         loss_mask = loss_mask.view(-1).float()
-        loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
+        loss = torch.sum(losses.view(-1) * loss_mask) / (
+            loss_mask.sum() * args.num_microbatches_in_minibatch)
         # Reduce loss for logging.
         averaged_loss = average_losses_across_data_parallel_group([loss])
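One subtlety of this normalization: loss_mask.sum() can differ across microbatches, so the accumulated gradient corresponds to the mean of per-microbatch token-averaged losses rather than an average over all unmasked tokens in the minibatch; the two coincide only when every microbatch has the same number of unmasked tokens.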