[syncBN] (#48)
* [syncBN]
added syncBN as a native pure-Python implementation in apex
added fused CUDA kernels used for sync BN, using Welford's algorithm for
mean/var (update/merge steps sketched below)
optional installation using 'python setup.py install --cuda_ext'
added unit test with side-by-side comparison between apex sync BN and
PyTorch BN. Note that because of numerical issues in the PyTorch BN
implementation's mean/var computation, the output will be slightly off.
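A minimal sketch of the Welford update and merge steps as CUDA device helpers, assuming illustrative names; this is not the actual contents of csrc/welford.cu:

```cuda
// Welford's online update: fold one value into a running (mean, m2, count).
// variance = m2 / count once all values have been folded in.
__device__ __forceinline__ void welford_update(
    float& mean, float& m2, float& count, float value) {
  count += 1.0f;
  float delta = value - mean;
  mean += delta / count;
  m2 += delta * (value - mean);
}

// Merge two partial (mean, m2, count) triples, e.g. when combining
// per-thread accumulators during a block-wide reduction.
__device__ __forceinline__ void welford_merge(
    float& mean_a, float& m2_a, float& count_a,
    float mean_b, float m2_b, float count_b) {
  float count = count_a + count_b;
  if (count == 0.0f) return;  // both sides empty: nothing to merge
  float delta = mean_b - mean_a;
  float nb_over_n = count_b / count;
  mean_a += delta * nb_over_n;
  m2_a += m2_b + delta * delta * count_a * nb_over_n;
  count_a = count;
}
```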
* [syncBN PR]
added fp16 support
addressing review comments on:
1. updating last pow 2 (a round-down-to-power-of-two helper is sketched below)
2. catching the import error when importing the syncBN kernel
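For the first point, a helper that rounds down to the nearest power of two, of the kind typically used when picking reduction block sizes, might look like the following; the name and placement are assumptions, not the PR's actual code:

```cuda
#include <algorithm>

// Round n down to the nearest power of two (returns 1 for n == 0).
// Illustrative host-side helper only.
static inline int last_pow2(unsigned int n) {
  n |= (n >> 1);
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8);
  n |= (n >> 16);
  return static_cast<int>(std::max(1u, n - (n >> 1)));
}
```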
* [syncBN PR]
added convert function to insert SyncBatchNorm
refactored some kernel code
* fixing type issues (fp16/fp32/fp64)
added Kahan summation (sketched below)
editing unit test to use PyTorch primitive ops in double precision; passing reasonable tests now
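A minimal sketch of Kahan (compensated) summation as a per-thread accumulator; the struct and names are illustrative, not the kernel's actual code:

```cuda
// Compensated (Kahan) summation: `correction` carries the low-order bits
// that a plain `sum += value` would lose to rounding.
struct KahanSum {
  float sum = 0.f;
  float correction = 0.f;

  __device__ void add(float value) {
    float y = value - correction;
    float t = sum + y;
    correction = (t - sum) - y;  // recovered rounding error
    sum = t;
  }
};
```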
* updating tensor creation calls
* fixing all_reduce to use a contiguous tensor
* transposed all_reduce results
* [syncBN]
support fp16 input & fp32 layer for apex fp16
partially fixing launch configs
enabling the imagenet example to run with --sync_bn
* [syncBN PR]
Documentation added
* adjusting README
* adjusting again
* added some doc to the imagenet example
* [syncBN]
warp-level reduction
bug fix: updated warp reduction logic; check for dummy elements to avoid NaN
(a guarded warp-level merge is sketched below).
improved launch config for better reduction kernels. A further improvement
would be to increase the grid size.
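A hedged sketch of such a guarded warp-level reduction of Welford triples, assuming 32-lane warps and illustrative names; dummy lanes carry count == 0 and are skipped so the merge never divides by zero:

```cuda
// Warp-level reduction of (mean, m2, count) triples via shuffles.
// Lanes holding no data carry a dummy triple with count == 0.
__device__ void warp_reduce_welford(float& mean, float& m2, float& count) {
  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
    // every lane executes the shuffles with a full mask
    float mean_b  = __shfl_down_sync(0xffffffff, mean,  offset);
    float m2_b    = __shfl_down_sync(0xffffffff, m2,    offset);
    float count_b = __shfl_down_sync(0xffffffff, count, offset);
    if (count_b > 0.f) {  // skip dummy elements: avoids 0/0 -> NaN
      float n = count + count_b;
      float delta = mean_b - mean;
      float nb_over_n = count_b / n;
      mean  += delta * nb_over_n;
      m2    += m2_b + delta * delta * count * nb_over_n;
      count  = n;
    }
  }
  // lane 0 now holds the merged (mean, m2, count) for the whole warp
}
```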
* [syncBN]
fixing undefined behavior in __shfl_down_sync caused by divergent threads in the
warp reduction (general pattern sketched below).
changing at::native::empty to at::empty (per upstream review comments)
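A sketch of the general shape of this fix, assuming 32-lane warps; it is illustrative, not the PR's exact code:

```cuda
// Undefined behavior: the shuffle sits inside a divergent branch, so lanes
// named in the mask may never execute the matching __shfl_down_sync, e.g.
//
//   if (lane < active) {                               // divergent in-warp
//     val += __shfl_down_sync(0xffffffff, val, offset);
//   }
//
// Fix: let every lane execute the shuffle unconditionally with the full
// mask, and only gate what is done with the result afterwards.
__device__ float warp_sum(float val) {
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_down_sync(0xffffffff, val, offset);  // all 32 lanes participate
  }
  return val;  // lane 0 holds the warp-wide sum
}
```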
New files: csrc/syncbn.cpp, csrc/welford.cu