Optimize the sync batchnorm by batching the communication (#980)
In this PR, we mainly tried to optimize the performance of Syncatchnorm and also fixed one potential issue in the welford_parallel kernel implementation. For performance improvement, we batched the mean/var/count all_gather communication together and sent it once in the forward path We also batch the all_reduce in backward path We add the contiguous call on the input of welford_parallel kernel. If there is any standard perf benchmark, I would be happy to run it.
Showing
Please register or sign in to comment