Unverified Commit 3b19c031 authored by arun99481's avatar arun99481 Committed by GitHub
Browse files

updating gather function with gather_for_metrics in run_wav2vec2_pretraining (#18877)


Co-authored-by: default avatarArun Rajaram <arunrajaram@Aruns-MacBook-Pro.local>
parent 734b7e2a
...@@ -596,7 +596,7 @@ def main(): ...@@ -596,7 +596,7 @@ def main():
# make sure that `num_losses` is summed for distributed training # make sure that `num_losses` is summed for distributed training
# and average gradients over losses of all devices # and average gradients over losses of all devices
if accelerator.state.num_processes > 1: if accelerator.state.num_processes > 1:
num_losses = accelerator.gather(num_losses).sum() num_losses = accelerator.gather_for_metrics(num_losses).sum()
gradient_multiplier = accelerator.state.num_processes / num_losses gradient_multiplier = accelerator.state.num_processes / num_losses
multiply_grads(model.module.parameters(), gradient_multiplier) multiply_grads(model.module.parameters(), gradient_multiplier)
else: else:
...@@ -647,10 +647,10 @@ def main(): ...@@ -647,10 +647,10 @@ def main():
outputs.diversity_loss.detach() outputs.diversity_loss.detach()
if accelerator.state.num_processes > 1: if accelerator.state.num_processes > 1:
loss = accelerator.gather(loss).sum() loss = accelerator.gather_for_metrics(loss).sum()
outputs.contrastive_loss = accelerator.gather(outputs.contrastive_loss).sum() outputs.contrastive_loss = accelerator.gather_for_metrics(outputs.contrastive_loss).sum()
outputs.diversity_loss = accelerator.gather(outputs.diversity_loss).sum() outputs.diversity_loss = accelerator.gather_for_metrics(outputs.diversity_loss).sum()
percent_masked = accelerator.gather(percent_masked).sum() percent_masked = accelerator.gather_for_metrics(percent_masked).sum()
train_logs = { train_logs = {
"loss": (loss * args.gradient_accumulation_steps) / num_losses, "loss": (loss * args.gradient_accumulation_steps) / num_losses,
...@@ -713,7 +713,7 @@ def main(): ...@@ -713,7 +713,7 @@ def main():
# sum over devices in multi-processing # sum over devices in multi-processing
if accelerator.num_processes > 1: if accelerator.num_processes > 1:
val_logs = {k: accelerator.gather(v).sum() for k, v in val_logs.items()} val_logs = {k: accelerator.gather_for_metrics(v).sum() for k, v in val_logs.items()}
val_logs = {k: v / val_logs["val_num_losses"] for k, v in val_logs.items()} val_logs = {k: v / val_logs["val_num_losses"] for k, v in val_logs.items()}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment