    Enable group batch norm (--bnp) on ROCm (only bn_group = 1) (#51) · e57c84e0
    sarunyap authored
    * Enable group batch norm (--bnp) on ROCm (only bn_group = 1)
    
    Enable NHWC group batch norm on a single GPU on ROCm (bn_group = 1).
    The multi-GPU case (bn_group > 1) will be revisited in the future.
    
    The following are the main changes:
    
    1) Use MIOpen data structures/functions in HIP instead of cuDNN
    2) For the warp-level primitive code, we ensure that the code operates
       on 64-thread-wide warps (wavefronts) instead of 32-thread-wide warps
    3) Disable all the bn_group > 1 paths
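    As a rough illustration of point 2 (a Python simulation, not the HIP
    kernel code), the sketch below models a shuffle-down tree reduction
    over one warp.  The function names are hypothetical.  The number of
    halving steps follows from the warp size (5 for 32 lanes, 6 for 64).
    A reduction hard-coded for 32 lanes sums only the lower half of a
    64-lane wavefront.

```python
def warp_reduce_sum(values, warp_size):
    """Simulate a shfl_down-style tree reduction over one warp.

    values[i] is the register held by lane i.  At each step every lane adds
    the value held by lane i + offset; lanes whose source is out of range add
    0.0 here, which does not affect lane 0.  Returns (lane 0's result, steps).
    """
    assert len(values) == warp_size
    regs = list(values)
    steps = 0
    offset = warp_size // 2  # derived from the warp size, not hard-coded
    while offset > 0:
        # Every lane reads before any lane writes, as in a real shuffle.
        shifted = [regs[i + offset] if i + offset < warp_size else 0.0
                   for i in range(warp_size)]
        regs = [r + s for r, s in zip(regs, shifted)]
        offset //= 2
        steps += 1
    return regs[0], steps


def reduce_hard_coded_32(values):
    """The same reduction with the starting offset hard-coded for 32 lanes."""
    regs = list(values)
    offset = 16  # assumes 32 lanes: lane 0 never sees lanes 32..63
    while offset > 0:
        shifted = [regs[i + offset] if i + offset < len(regs) else 0.0
                   for i in range(len(regs))]
        regs = [r + s for r, s in zip(regs, shifted)]
        offset //= 2
    return regs[0]


print(warp_reduce_sum([1.0] * 32, 32))   # (32.0, 5)
print(warp_reduce_sum([1.0] * 64, 64))   # (64.0, 6)
print(reduce_hard_coded_32([1.0] * 64))  # 32.0: upper half of the wavefront lost
```

    Because the starting offset is derived from the warp size rather than
    from a constant, the same reduction works for both 32- and 64-wide
    warps.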
    
    Notes:
    
    1) Multi-stream is not tested.
    2) We have not optimized for performance.
    
    * Fix bnp hipification
    
    Avoid calling hipify-perl in setup.py and rely on PyTorch's internal
    hipification mechanism.
    
    * Make bnp data pointers contiguous
    
    The contrib group batch norm implementation assumes that all input
    tensors are contiguous.  When non-contiguous tensors are passed to the
    function, it produces incorrect results.  This commit explicitly calls
    .contiguous() on all input tensors before accessing them.
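    The failure mode can be sketched in plain Python.  This is an
    illustration under simplified assumptions, not the contrib code: a
    kernel that indexes the raw buffer as if it were densely packed reads
    the wrong elements from a strided view such as a transpose.  The
    helpers `materialize` and `naive_kernel_read` are hypothetical stand-ins
    for .contiguous() and for a kernel that ignores strides.

```python
def strided_get(buf, offset, strides, index):
    """Read element `index` (a tuple) of a strided view over flat buffer `buf`."""
    return buf[offset + sum(i * s for i, s in zip(index, strides))]


def materialize(buf, offset, shape, strides):
    """Copy a strided 2-D view into a new densely packed row-major buffer
    (the role .contiguous() plays for a non-contiguous tensor)."""
    rows, cols = shape
    return [strided_get(buf, offset, strides, (r, c))
            for r in range(rows) for c in range(cols)]


def naive_kernel_read(buf, shape):
    """A 'kernel' that assumes row-major contiguous data and ignores strides."""
    rows, cols = shape
    return [buf[r * cols + c] for r in range(rows) for c in range(cols)]


# A 2x3 row-major buffer: [[0, 1, 2], [3, 4, 5]].
base = [0, 1, 2, 3, 4, 5]
# Its transpose is a 3x2 view over the same storage with swapped strides.
t_shape, t_strides = (3, 2), (1, 3)

expected = materialize(base, 0, t_shape, t_strides)  # [0, 3, 1, 4, 2, 5]
wrong = naive_kernel_read(base, t_shape)             # [0, 1, 2, 3, 4, 5]
right = naive_kernel_read(expected, t_shape)         # same as expected
```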
    
    * Fix HIP lane id in bnp
    
    Fix typo
    
    * Fix ReLU bitmask for HIP in bnp
    
    The ReLU bitmask is derived using the __ballot function, which returns
    a 64-bit value in HIP (one bit per lane of a 64-wide wavefront).  This
    commit fixes the ReLU bitmask storage size and offsets on ROCm.
    
    This patch also fixes the kernel to set the ReLU bitmask bit to 1 when
    the data is less than or equal to zero (not only when it is less than
    zero).  Not doing so can cause a stability issue.
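    The sketch below is a Python simulation of a ballot-based ReLU bitmask
    (hypothetical helper names, not the kernel code).  Each lane contributes
    one bit, so one ballot yields a word as wide as the warp: 32 bits on
    CUDA and 64 bits on ROCm.  Storing a 64-bit ballot in a 32-bit slot
    drops lanes 32-63, and a strict `< 0` predicate misses exact zeros.

```python
def ballot(predicates):
    """Simulate __ballot: bit i of the result is set iff lane i's predicate holds."""
    mask = 0
    for lane, pred in enumerate(predicates):
        if pred:
            mask |= 1 << lane
    return mask


def relu_bitmask_words(values, warp_size):
    """Build a ReLU bitmask with one ballot (one warp_size-bit word) per warp.

    A set bit marks an element that ReLU rectifies, i.e. value <= 0.
    """
    return [ballot([v <= 0.0 for v in values[start:start + warp_size]])
            for start in range(0, len(values), warp_size)]


values = [1.0] * 64
values[0] = 0.0    # exactly zero: must be rectified
values[40] = -2.0  # negative, in the upper half of a 64-lane wavefront

words64 = relu_bitmask_words(values, 64)
print(words64 == [(1 << 0) | (1 << 40)])             # True
print(words64[0] & 0xFFFFFFFF)                       # 1: a 32-bit slot loses lane 40
print(ballot([v < 0.0 for v in values]) == 1 << 40)  # True: strict < misses lane 0
```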
    
    * Remove multiple of 64 offset for HIP in bnp
    
    The multiple of 64 offset is not necessary.
    
    * Use FP16 intermediate output to determine whether to rectify in bnp
    
    Group batch norm takes FP16 tensors and produces FP16 output; however,
    all arithmetic operations are done in FP32, so intermediate outputs are
    in FP32.  For the fusion kernels, ReLU examines the FP32 intermediate
    output to decide whether to rectify it.  ReLU must rectify the
    intermediate output if it is less than or "equal" to zero.  An FP32
    intermediate output can be positive but so close to zero that it
    becomes zero when converted to FP16.  In this case, the output is not
    rectified when it should be.  Since the output is not rectified in the
    forward pass, the gradient is not rectified in the backward pass.  This
    can cause a stability issue.
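    The rounding hazard can be reproduced with Python's standard `struct`
    module, whose half-precision `e` format emulates the FP32-to-FP16
    conversion.  The two `relu_decision_*` functions are hypothetical.  They
    contrast deciding on the FP32 intermediate (old behavior) with deciding
    on the FP16 value that is actually stored (this patch).

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x):
    """Round a Python float to the nearest IEEE single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def relu_decision_fp32(y32):
    """Old behavior: decide on the FP32 intermediate, then store as FP16."""
    rectify = y32 <= 0.0
    out16 = 0.0 if rectify else to_fp16(y32)
    return out16, rectify


def relu_decision_fp16(y32):
    """Patched behavior: convert to FP16 first, then decide on that value."""
    y16 = to_fp16(y32)
    rectify = y16 <= 0.0
    out16 = 0.0 if rectify else y16
    return out16, rectify


# Positive in FP32, but below half the smallest FP16 subnormal (2**-25).
y = to_fp32(1e-8)
old = relu_decision_fp32(y)
new = relu_decision_fp16(y)
print(y > 0.0)  # True
print(old)      # (0.0, False): stored as zero, yet the gradient is not masked
print(new)      # (0.0, True)
```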
    
    This patch can have a negative impact on the performance of group batch
    norm because it performs the FP32-to-FP16 conversion multiple times.
    
    * Disable dispatchX ParallelSums in HIP in bnp
    
    dispatchX is not required for the bn_group = 1 case.
    
    * Use traditional load/store for HIP in bnp
    
    The built-in function has a high floating point rounding error.  Thus,
    we replace it with traditional load/store.  Doing so breaks the
    aligned-pointer property of the load/store functions; therefore, we
    conservatively use traditional load/store for all memory accesses.
    
    * Replace shfl_down with shfl_sync in parallel sums for HIP in bnp
    
    This commit separates the HIP code from the CUDA code in parallel sums.
    
    * Remove -U__HIP_NO_HALF_CONVERSIONS__ for HIP in bnp
    
    Since the built-in function is removed, -U__HIP_NO_HALF_CONVERSIONS__ is
    no longer needed.
    
    * Preserve CUDA's ReLU condition path for USE_ADD_RELU in bnp
    
    * Add test for bnp
    
    The test evaluates the correctness of batch norm, batch norm + ReLU,
    and batch norm + add + ReLU against reference implementations.
    
    For the forward activation output, we validate it against PyTorch's
    implementation.  The group batch norm activation output must be allclose
    to the PyTorch activation output for the test to pass.
    
    For the backward gradient output, we validate it against the Python
    implementation.  Due to floating point rounding error in the batch norm
    implementation, the group batch norm gradient output might not be
    allclose to the Python implementation's output when ReLU is used, even
    though most elements are very close to each other.  Thus, we use a
    norm-difference threshold instead of allclose to determine whether the
    test passes or fails.
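    Below is a minimal plain-Python sketch of the two acceptance criteria.
    The helper names and the 1e-3 threshold are assumptions for
    illustration, not the values used by the actual test.  A single gradient
    element that is off by slightly more than the allclose tolerance fails
    allclose but barely moves the relative norm difference.

```python
import math


def allclose(a, b, rtol=1e-5, atol=1e-8):
    """Element-wise |a - b| <= atol + rtol * |b|, the NumPy/PyTorch formula."""
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


def relative_norm_diff(a, b):
    """||a - b||_2 / ||b||_2, with b taken as the reference."""
    diff = math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))
    ref = math.sqrt(sum(y * y for y in b))
    return diff / ref if ref > 0 else diff


def gradients_match(out, ref, threshold=1e-3):
    """Pass if the relative norm difference is below `threshold`.

    The default threshold is an illustrative assumption.
    """
    return relative_norm_diff(out, ref) < threshold


ref = [0.5] * 1000
out = list(ref)
out[7] = 0.5 + 1e-3  # one element slightly outside the allclose tolerance

print(allclose(out, ref))         # False
print(gradients_match(out, ref))  # True
```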
    
    * Use the warp size variable instead of hard-coding the warp size in bnp
    
    Use C10_WARP_SIZE from c10/macros/Macros.h in the host functions and
    warpSize in the device kernels instead of hard-coding the warp size.