Unverified commit 8f1b5782, authored by Ramon Zhou, committed by GitHub
Browse files

[Misc] Add dist.barrier to fix low accuracy issue in the example. (#6544)

parent 0b5abba8
...@@ -278,6 +278,7 @@ def train( ...@@ -278,6 +278,7 @@ def train(
dist.reduce(tensor=acc, dst=0) dist.reduce(tensor=acc, dst=0)
total_loss /= step + 1 total_loss /= step + 1
dist.reduce(tensor=total_loss, dst=0) dist.reduce(tensor=total_loss, dst=0)
dist.barrier()
epoch_end = time.time() epoch_end = time.time()
if rank == 0: if rank == 0:
...@@ -349,6 +350,7 @@ def run(rank, world_size, args, devices, dataset): ...@@ -349,6 +350,7 @@ def run(rank, world_size, args, devices, dataset):
/ world_size / world_size
) )
dist.reduce(tensor=test_acc, dst=0) dist.reduce(tensor=test_acc, dst=0)
dist.barrier()
if rank == 0: if rank == 0:
print(f"Test Accuracy {test_acc.item():.4f}") print(f"Test Accuracy {test_acc.item():.4f}")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment