Commit cd499559 authored by Jared Casper
Browse files

Revert incorrect fix.

parent 2a34e0ec
...@@ -84,7 +84,7 @@ def average_losses_across_data_parallel_group(losses): ...@@ -84,7 +84,7 @@ def average_losses_across_data_parallel_group(losses):
[loss.clone().detach().view(1) for loss in losses]) [loss.clone().detach().view(1) for loss in losses])
torch.distributed.all_reduce(averaged_losses, torch.distributed.all_reduce(averaged_losses,
group=mpu.get_data_parallel_group()) group=mpu.get_data_parallel_group())
averaged_losses = averaged_losses / averaged_losses = averaged_losses / \
torch.distributed.get_world_size(group=mpu.get_data_parallel_group()) torch.distributed.get_world_size(group=mpu.get_data_parallel_group())
return averaged_losses return averaged_losses
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment