Unverified Commit 8f1b5782 authored by Ramon Zhou's avatar Ramon Zhou Committed by GitHub
Browse files

[Misc] Add dist.barrier to fix low accuracy issue in the example. (#6544)

parent 0b5abba8
......@@ -278,6 +278,7 @@ def train(
dist.reduce(tensor=acc, dst=0)
total_loss /= step + 1
dist.reduce(tensor=total_loss, dst=0)
dist.barrier()
epoch_end = time.time()
if rank == 0:
......@@ -349,6 +350,7 @@ def run(rank, world_size, args, devices, dataset):
/ world_size
)
dist.reduce(tensor=test_acc, dst=0)
dist.barrier()
if rank == 0:
print(f"Test Accuracy {test_acc.item():.4f}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment