    Optimize the sync batchnorm by batching the communication (#980) · 8a1ed9e8
    lly-zero-one authored
    In this PR, we mainly optimize the performance of SyncBatchNorm and also fix a potential issue in the welford_parallel kernel implementation.
    
    For performance, we batch the mean/var/count all_gather communication together and send it as a single collective in the forward path.
    We also batch the all_reduce calls in the backward path.
    We add a contiguous() call on the input of the welford_parallel kernel.
    If there is a standard performance benchmark, I would be happy to run it.
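A minimal single-process sketch of the forward-path batching, assuming each rank holds a per-channel mean, a biased variance, and a local sample count. All function names here are hypothetical. `fake_all_gather` only simulates the collective (in PyTorch this role would presumably be played by one `torch.distributed.all_gather` on a concatenated tensor), and the merge uses the standard parallel Welford (Chan et al.) update rather than the PR's actual welford_parallel kernel.

```python
from typing import List, Tuple

def local_stats(rows: List[List[float]]) -> Tuple[List[float], List[float], int]:
    """Per-channel mean and biased variance of one rank's samples (each row is a sample)."""
    n = len(rows)
    num_channels = len(rows[0])
    mean = [sum(r[c] for r in rows) / n for c in range(num_channels)]
    var = [sum((r[c] - mean[c]) ** 2 for r in rows) / n for c in range(num_channels)]
    return mean, var, n

def pack_stats(mean: List[float], var: List[float], count: int) -> List[float]:
    """Concatenate mean, var and count into one flat buffer so one collective carries all three."""
    return list(mean) + list(var) + [float(count)]

def fake_all_gather(buffers: List[List[float]]) -> List[List[float]]:
    """Single-process stand-in for one all_gather: every rank receives every rank's buffer."""
    return [list(b) for b in buffers]

def welford_combine(gathered: List[List[float]], num_channels: int) -> Tuple[List[float], List[float], float]:
    """Merge per-rank (mean, biased var, count) with the parallel Welford update."""
    mean = [0.0] * num_channels
    m2 = [0.0] * num_channels  # running sum of squared deviations per channel
    n = 0.0
    for buf in gathered:
        # Unpack the single gathered buffer back into its three logical pieces.
        mean_b = buf[:num_channels]
        var_b = buf[num_channels:2 * num_channels]
        n_b = buf[2 * num_channels]
        if n_b == 0:
            continue
        n_ab = n + n_b
        for c in range(num_channels):
            delta = mean_b[c] - mean[c]
            mean[c] += delta * n_b / n_ab
            m2[c] += var_b[c] * n_b + delta * delta * n * n_b / n_ab
        n = n_ab
    var = [x / n for x in m2]  # biased (population) variance
    return mean, var, n

# Two simulated ranks with 2 channels each; rank 1 has more samples than rank 0.
rank_data = [
    [[1.0, 10.0], [2.0, 20.0]],
    [[3.0, 30.0], [4.0, 40.0], [5.0, 50.0]],
]
packed = [pack_stats(*local_stats(d)) for d in rank_data]
mean, var, n = welford_combine(fake_all_gather(packed), num_channels=2)
```

With one packed buffer, the per-collective launch and latency overhead is paid once instead of three times; the only requirement is that every rank agrees on the packing offsets.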
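The backward-path batching can be sketched the same way, assuming the backward pass needs two per-channel partial sums reduced across ranks (called `sum_dy` and `sum_dy_xmu` here purely for illustration). `fake_all_reduce_sum` and `batched_backward_reduce` are hypothetical single-process stand-ins, not the PR's code.

```python
from typing import List, Tuple

def fake_all_reduce_sum(buffers: List[List[float]]) -> List[float]:
    """Single-process stand-in for one all_reduce with a SUM op: elementwise sum over ranks."""
    return [sum(vals) for vals in zip(*buffers)]

def batched_backward_reduce(per_rank: List[Tuple[List[float], List[float]]]) -> Tuple[List[float], List[float]]:
    """Reduce two per-channel sums with one collective instead of two."""
    num_channels = len(per_rank[0][0])
    # Concatenate both partial sums so a single all_reduce carries them.
    packed = [list(sum_dy) + list(sum_dy_xmu) for sum_dy, sum_dy_xmu in per_rank]
    reduced = fake_all_reduce_sum(packed)
    # Split the reduced buffer back into its two logical pieces.
    return reduced[:num_channels], reduced[num_channels:]

# Two simulated ranks, 2 channels each: (sum_dy, sum_dy_xmu) per rank.
per_rank = [
    ([1.0, 2.0], [0.5, -1.0]),
    ([3.0, 4.0], [1.5, 2.0]),
]
sum_dy, sum_dy_xmu = batched_backward_reduce(per_rank)

# Reference: two separate all_reduce calls.
separate_dy = fake_all_reduce_sum([r[0] for r in per_rank])
separate_xmu = fake_all_reduce_sum([r[1] for r in per_rank])
```

Batching is safe here because a SUM reduction is elementwise, so reducing the concatenation yields the same values as reducing each part separately.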
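The commit does not say exactly what failed without the contiguous call, so the following is only a toy model of why it can matter. It assumes the kernel indexes its input as densely packed row-major memory, which a strided view (here, a transpose sharing the same storage) does not satisfy. `contiguous` below is a hypothetical helper that mimics the idea behind PyTorch's `Tensor.contiguous()` by materializing a row-major copy.

```python
from typing import List, Tuple

def strided_get(storage: List[float], strides: Tuple[int, int], i: int, j: int) -> float:
    """Read element (i, j) of a 2-D view through its strides."""
    return storage[i * strides[0] + j * strides[1]]

def contiguous(storage: List[float], shape: Tuple[int, int], strides: Tuple[int, int]) -> List[float]:
    """Materialize a dense row-major copy of a strided 2-D view."""
    rows, cols = shape
    return [strided_get(storage, strides, i, j) for i in range(rows) for j in range(cols)]

def row_sums_assuming_row_major(flat: List[float], shape: Tuple[int, int]) -> List[float]:
    """Toy 'kernel' that, like many CUDA kernels, indexes flat memory as dense row-major."""
    rows, cols = shape
    return [sum(flat[i * cols + j] for j in range(cols)) for i in range(rows)]

# Row-major storage of the 2x3 matrix [[1, 2, 3], [4, 5, 6]].
storage = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
# Its transpose as a 3x2 view over the same storage (no copy): strides (1, 3).
t_shape, t_strides = (3, 2), (1, 3)

# Passing the raw storage of the strided view reads the wrong elements.
wrong = row_sums_assuming_row_major(storage, t_shape)
# Making the view contiguous first gives the intended result.
right = row_sums_assuming_row_major(contiguous(storage, t_shape, t_strides), t_shape)
```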