Commit 93ecbc82 authored by yuguo's avatar yuguo
Browse files
parents bb8cf71b 736e8f8b
...@@ -352,6 +352,39 @@ void __launch_bounds__(THREADS_PER_BLOCK) bias_gradient_kernel(const Tin* in, fl ...@@ -352,6 +352,39 @@ void __launch_bounds__(THREADS_PER_BLOCK) bias_gradient_kernel(const Tin* in, fl
atomicAdd(&out[col_idx], local_sum); atomicAdd(&out[col_idx], local_sum);
} }
constexpr int kColwiseReduceTileSize = 32;

// Tree sum-reduction across warp lanes using register shuffles.
//
// `max` is the FIRST shuffle offset, so it must be HALF the number of lanes
// holding live partial sums: pass 16 to reduce 32 values (offsets then go
// 16, 8, 4, 2, 1). After the loop, lane 0 of each reduced group holds the
// total; other lanes hold unspecified partial garbage.
//
// The previous default of 32 was wrong for the intended 32-value reduction:
// a first offset equal to the live-lane count either returns the caller's
// own value (double-counting it) or, on a 64-wide AMD wavefront, pulls a
// value from the other half-wavefront. The default is now 16; the existing
// call site passes kColwiseReduceTileSize / 2 == 16 explicitly and is
// unaffected.
//
// NOTE(review): this uses the mask-less __shfl_down, which is the HIP/ROCm
// API used throughout this file (hipLaunchKernelGGL / hipStream_t); on
// NVIDIA CUDA Volta+ the *_sync variants would be required instead.
template <typename T>
__inline__ __device__ T WarpReduceSum(T val, int max = kColwiseReduceTileSize / 2) {
  for (int offset = max; offset > 0; offset >>= 1) {
    val += __shfl_down(val, offset);
  }
  return val;
}
// Column-wise sum reduction: dst[j] = sum_{i=0..M-1} src[i * N + j],
// accumulated in float regardless of InputType.
//
// Expected launch configuration (see bias_gradient_kernelLauncher):
//   block = (kColwiseReduceTileSize, kColwiseReduceTileSize) = (32, 32)
//   grid  = ceil(N / kColwiseReduceTileSize) blocks along x
//
// Phase 1: each thread strides down the rows of one column so global loads
// of src are coalesced along x. Phase 2: the 32x32 tile of partials is
// transposed through shared memory and each column is finished with a warp
// shuffle reduction.
template <typename InputType>
__launch_bounds__(1024) __global__
void bias_gradient_kernel_v2(float *dst, const InputType *src, int M, int N) {
  // +1 padding on the inner dimension so the transposed (column) read
  // g_shared[threadIdx.x][threadIdx.y] below does not serialize on
  // shared-memory bank conflicts.
  __shared__ float g_shared[kColwiseReduceTileSize][kColwiseReduceTileSize + 1];
  const int j = blockIdx.x * blockDim.x + threadIdx.x;
  float grad_sum = 0.f;
  if (j < N) {
    // blockDim.y threads cooperatively walk the M rows of column j.
    for (int i = threadIdx.y; i < M; i += blockDim.y) {
      grad_sum += static_cast<float>(src[i * N + j]);
    }
  }
  // Out-of-range columns contribute 0, so every thread stores
  // unconditionally and the barrier is reached by the whole block.
  g_shared[threadIdx.y][threadIdx.x] = grad_sum;
  __syncthreads();
  // Transposed read: thread (x, y) now holds partial #x of column
  // (blockIdx.x * blockDim.x + y); the 32 partials of one column sit in
  // consecutive lanes and collapse with a shuffle reduction (first
  // offset 16 == kColwiseReduceTileSize / 2).
  float sum = g_shared[threadIdx.x][threadIdx.y];
  sum = WarpReduceSum<float>(sum, kColwiseReduceTileSize / 2);
  if (threadIdx.x == 0) {
    const int j = blockIdx.x * blockDim.x + threadIdx.y;
    if (j < N) {
      dst[j] = sum;  // already float; the previous static_cast was a no-op
    }
  }
}
template <typename Tin> template <typename Tin>
void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool stream_order_alloc, hipStream_t stream) { void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool stream_order_alloc, hipStream_t stream) {
dim3 block, grid; dim3 block, grid;
...@@ -364,7 +397,9 @@ void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool ...@@ -364,7 +397,9 @@ void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool
}else{ }else{
NVTE_CHECK_CUDA( hipMemsetAsync(out, 0, n*sizeof(float), stream) ); NVTE_CHECK_CUDA( hipMemsetAsync(out, 0, n*sizeof(float), stream) );
} }
hipLaunchKernelGGL(( bias_gradient_kernel<Tin, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0, stream, in, out, m, n); // hipLaunchKernelGGL(( bias_gradient_kernel<Tin, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0, stream, in, out, m, n);
int B =(n - 1) / kColwiseReduceTileSize + 1;
bias_gradient_kernel_v2<Tin><<<B, dim3(kColwiseReduceTileSize, kColwiseReduceTileSize), 0, stream>>>(out, in, m, n);
} }
} // namespace detail } // namespace detail
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment