"vscode:/vscode.git/clone" did not exist on "7d7edf6d37576fb6eda65db6db43fda54a7f06ba"
Unverified Commit 9c80f6d3 authored by Chaitanya Sri Krishna Lolla's avatar Chaitanya Sri Krishna Lolla Committed by GitHub
Browse files

Enable sync batchnorm extension. (#27)

* Enable sync batchnorm

* enable syncbn properly

* update the unit tests

* update tests

* update conditions for welford_merge_element

* updated conditions based on comments.
parent 33a3a667
......@@ -11,6 +11,11 @@
#include "type_shim.h"
#include "compat.h"
#if defined __HIP_PLATFORM_HCC__
#define SHFL_DOWN __shfl_down
#else
#define SHFL_DOWN __shfl_down_sync
#endif
__device__ __forceinline__ int lastpow2(int n)
{
......@@ -47,7 +52,7 @@ __device__ __forceinline__ T warp_reduce_sum(T val)
{
#pragma unroll
for(int i = WARP_SIZE/2; i > 0; i >>= 1)
val = val + __shfl_down_sync(0xffffffff, val, i);
val = val + SHFL_DOWN(0xffffffff, val, i);
return val;
}
......@@ -129,10 +134,14 @@ __device__ __forceinline__ void warp_reduce_mean_m2n(T &mean, T &m2n, int &num)
{
#pragma unroll
for(int i = WARP_SIZE/2; i > 0; i >>= 1) {
auto num_new = __shfl_down_sync(0xffffffff, num, i);
auto mean_new = __shfl_down_sync(0xffffffff, mean, i);
auto m2n_new = __shfl_down_sync(0xffffffff, m2n, i);
auto num_new = SHFL_DOWN(0xffffffff, num, i);
auto mean_new = SHFL_DOWN(0xffffffff, mean, i);
auto m2n_new = SHFL_DOWN(0xffffffff, m2n, i);
#if defined __HIP_PLATFORM_HCC__
welford_merge_element<T, int>(num, mean, m2n, num_new, mean_new, m2n_new);
#else
welford_merge_element(num, mean, m2n, num_new, mean_new, m2n_new);
#endif
}
}
......
......@@ -189,7 +189,12 @@ if "--cuda_ext" in sys.argv:
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros}))
else:
print ("INFO: Skipping syncbn extension.")
print ("INFO: Building syncbn extension.")
ext_modules.append(
CUDAExtension(name='syncbn',
sources=['csrc/syncbn.cpp',
'csrc/hip/welford.hip'],
extra_compile_args=['-O3'] + version_dependent_macros))
if not is_rocm_pytorch:
......
......@@ -6,8 +6,8 @@ export WORLD_SIZE=2
# Test with opt_level="O2"
echo "running opt_level O2"
python3.6 -m torch.distributed.launch --nproc_per_node=2 amp_master_params.py --opt_level "O2"
python3.6 compare.py
python3.6 -m torch.distributed.launch --nproc_per_node=2 amp_master_params/amp_master_params.py --opt_level "O2"
python3.6 amp_master_params/compare.py
# delete the model files
echo -e "O2 test completed. Deleting model files\n"
......@@ -19,8 +19,8 @@ rm rank1master.pth
# Test with opt_level="O5"
echo "running opt_level O5"
python3.6 -m torch.distributed.launch --nproc_per_node=2 amp_master_params.py --opt_level "O5"
python3.6 compare.py
python3.6 -m torch.distributed.launch --nproc_per_node=2 amp_master_params/amp_master_params.py --opt_level "O5"
python3.6 amp_master_params/compare.py
# delete the model files
echo "O5 test completed. Deleting model files"
......@@ -28,3 +28,10 @@ rm rank0model.pth
rm rank1model.pth
rm rank0master.pth
rm rank1master.pth
## Run the Sync BN Tests.
echo "Running syncbn tests"
python3.6 -m torch.distributed.launch --nproc_per_node=2 synced_batchnorm/two_gpu_test_different_batch_size.py --apex
echo "Running syncbn python only tests"
python3.6 synced_batchnorm/python_single_gpu_unit_test.py
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment