Commit 1f6d6678 authored by zhuwenwen's avatar zhuwenwen
Browse files

support dtk2304

parent dd424ec5
// part of code modified from https://github.com/NVIDIA/apex
//#include <cooperative_groups.h>
#include <hip/hsa_detail/hip_cooperative_groups.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <hip/hip_cooperative_groups.h>
#include <torch/extension.h>
#include <THC/THCDeviceUtils.cuh>
......@@ -48,15 +48,15 @@ __inline__ __device__ void WelfordWarpAllReduce(float thread_mean, float thread_
*m2 = thread_m2;
*count = thread_count;
for (int mask = 1; mask < 32; mask *= 2) {
float b_mean = __shfl_down_sync(0xffffffff, *mean, mask);
float b_m2 = __shfl_down_sync(0xffffffff, *m2, mask);
float b_count = __shfl_down_sync(0xffffffff, *count, mask);
float b_mean = __shfl_down(*mean, mask);
float b_m2 = __shfl_down(*m2, mask);
float b_count = __shfl_down(*count, mask);
WelfordOnline(b_mean, b_m2, b_count, mean, m2, count);
}
*mean = __shfl_sync(0xffffffff, *mean, 0, 32);
*m2 = __shfl_sync(0xffffffff, *m2, 0, 32);
*count = __shfl_sync(0xffffffff, *count, 0, 32);
*mean = __shfl(*mean, 0, 32);
*m2 = __shfl(*m2, 0, 32);
*count = __shfl(*count, 0, 32);
}
template <typename T>
......
......@@ -17,14 +17,14 @@
__inline__ __device__ float WarpAllReduceMax(float val) {
for (int mask = 1; mask < 32; mask *= 2) {
val = max(val, __shfl_xor_sync(0xffffffff, val, mask));
val = max(val, __shfl_xor(val, mask));
}
return val;
}
__inline__ __device__ float WarpAllReduceSum(float val) {
for (int mask = 1; mask < 32; mask *= 2) {
val += __shfl_xor_sync(0xffffffff, val, mask);
val += __shfl_xor(val, mask);
}
return val;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment