Commit 1f6d6678 authored by zhuwenwen's avatar zhuwenwen
Browse files

support dtk2304

parent dd424ec5
// part of code modified from https://github.com/NVIDIA/apex // part of code modified from https://github.com/NVIDIA/apex
//#include <cooperative_groups.h> //#include <cooperative_groups.h>
#include <hip/hsa_detail/hip_cooperative_groups.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <hip/hip_cooperative_groups.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <THC/THCDeviceUtils.cuh> #include <THC/THCDeviceUtils.cuh>
...@@ -48,15 +48,15 @@ __inline__ __device__ void WelfordWarpAllReduce(float thread_mean, float thread_ ...@@ -48,15 +48,15 @@ __inline__ __device__ void WelfordWarpAllReduce(float thread_mean, float thread_
*m2 = thread_m2; *m2 = thread_m2;
*count = thread_count; *count = thread_count;
for (int mask = 1; mask < 32; mask *= 2) { for (int mask = 1; mask < 32; mask *= 2) {
float b_mean = __shfl_down_sync(0xffffffff, *mean, mask); float b_mean = __shfl_down(*mean, mask);
float b_m2 = __shfl_down_sync(0xffffffff, *m2, mask); float b_m2 = __shfl_down(*m2, mask);
float b_count = __shfl_down_sync(0xffffffff, *count, mask); float b_count = __shfl_down(*count, mask);
WelfordOnline(b_mean, b_m2, b_count, mean, m2, count); WelfordOnline(b_mean, b_m2, b_count, mean, m2, count);
} }
*mean = __shfl_sync(0xffffffff, *mean, 0, 32); *mean = __shfl(*mean, 0, 32);
*m2 = __shfl_sync(0xffffffff, *m2, 0, 32); *m2 = __shfl(*m2, 0, 32);
*count = __shfl_sync(0xffffffff, *count, 0, 32); *count = __shfl(*count, 0, 32);
} }
template <typename T> template <typename T>
......
...@@ -17,14 +17,14 @@ ...@@ -17,14 +17,14 @@
__inline__ __device__ float WarpAllReduceMax(float val) { __inline__ __device__ float WarpAllReduceMax(float val) {
for (int mask = 1; mask < 32; mask *= 2) { for (int mask = 1; mask < 32; mask *= 2) {
val = max(val, __shfl_xor_sync(0xffffffff, val, mask)); val = max(val, __shfl_xor(val, mask));
} }
return val; return val;
} }
__inline__ __device__ float WarpAllReduceSum(float val) { __inline__ __device__ float WarpAllReduceSum(float val) {
for (int mask = 1; mask < 32; mask *= 2) { for (int mask = 1; mask < 32; mask *= 2) {
val += __shfl_xor_sync(0xffffffff, val, mask); val += __shfl_xor(val, mask);
} }
return val; return val;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment