Unverified commit 1cb3da87, authored by Hubert Lu, committed by GitHub

Optimize layer normalization for AMD GPUs (#66)

* Optimize fused layer normalization for MI100

* Optimize cuComputePartGradGammaBeta for AMD GPUs
parent 151d150b
@@ -8,6 +8,7 @@
 #include "type_shim.h"
 template<typename U> __device__
 void cuWelfordOnlineSum(
   const U curr,
@@ -56,7 +57,8 @@ void cuWelfordMuSigma2(
   const int i1,
   U& mu,
   U& sigma2,
-  U* buf)
+  U* buf,
+  const int GPU_WARP_SIZE)
 {
   // Assumptions:
   // 1) blockDim.x == warpSize
@@ -86,12 +88,12 @@ void cuWelfordMuSigma2(
       cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
     }
     // intra-warp reductions
-    for (int l = 0; l <= 4; ++l) {
-      int srcLaneB = (threadIdx.x+(1<<l))&31;
-      U muB = WARP_SHFL(mu, srcLaneB, 32);
-      U countB = WARP_SHFL(count, srcLaneB, 32);
-      U sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
-      cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
+    #pragma unroll
+    for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) {
+      U muB = WARP_SHFL_DOWN(mu, stride);
+      U countB = WARP_SHFL_DOWN(count, stride);
+      U sigma2B = WARP_SHFL_DOWN(sigma2, stride);
+      cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
     }
     // threadIdx.x == 0 has correct values for each warp
     // inter-warp reductions
@@ -126,8 +128,8 @@ void cuWelfordMuSigma2(
       sigma2 = ubuf[1]/U(n2);
       // don't care about final value of count, we know count == n2
     } else {
-      mu = WARP_SHFL(mu, 0, 32);
-      sigma2 = WARP_SHFL(sigma2/U(n2), 0, 32);
+      mu = WARP_SHFL(mu, 0);
+      sigma2 = WARP_SHFL(sigma2 / U(n2), 0);
     }
   }
 }
@@ -140,7 +142,8 @@ void cuWelfordMuSigma2(
   const int i1,
   float& mu,
   float& sigma2,
-  float* buf)
+  float* buf,
+  const int GPU_WARP_SIZE)
 {
   // Assumptions:
   // 1) blockDim.x == warpSize
@@ -181,12 +184,12 @@ void cuWelfordMuSigma2(
       cuWelfordOnlineSum(curr,mu,sigma2,count);
     }
     // intra-warp reductions
-    for (int l = 0; l <= 4; ++l) {
-      int srcLaneB = (threadIdx.x+(1<<l))&31;
-      float muB = WARP_SHFL(mu, srcLaneB, 32);
-      float countB = WARP_SHFL(count, srcLaneB, 32);
-      float sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
-      cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
+    #pragma unroll
+    for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) { // TODO
+      float muB = WARP_SHFL_DOWN(mu, stride);
+      float countB = WARP_SHFL_DOWN(count, stride);
+      float sigma2B = WARP_SHFL_DOWN(sigma2, stride);
+      cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
     }
     // threadIdx.x == 0 has correct values for each warp
     // inter-warp reductions
@@ -221,8 +224,8 @@ void cuWelfordMuSigma2(
       sigma2 = ubuf[1]/float(n2);
       // don't care about final value of count, we know count == n2
     } else {
-      mu = WARP_SHFL(mu, 0, 32);
-      sigma2 = WARP_SHFL(sigma2/float(n2), 0, 32);
+      mu = WARP_SHFL(mu, 0);
+      sigma2 = WARP_SHFL(sigma2 / float(n2), 0);
     }
   }
 }
@@ -292,7 +295,8 @@ void cuApplyLayerNorm_(
   const int n2,
   const U epsilon,
   const V* __restrict__ gamma,
-  const V* __restrict__ beta
+  const V* __restrict__ beta,
+  const int GPU_WARP_SIZE
   )
 {
   // Assumptions:
@@ -303,7 +307,7 @@ void cuApplyLayerNorm_(
     SharedMemory<U> shared;
     U* buf = shared.getPointer();
     U mu,sigma2;
-    cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf);
+    cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf, GPU_WARP_SIZE);
     const T* lvals = vals + i1*n2;
     V* ovals = output_vals + i1*n2;
     U c_invvar = rsqrt(sigma2 + epsilon);
@@ -338,13 +342,12 @@ void cuApplyLayerNorm(
   const int n2,
   const U epsilon,
   const V* __restrict__ gamma,
-  const V* __restrict__ beta
-  )
+  const V* __restrict__ beta,
+  const int warp_size)
 {
-  cuApplyLayerNorm_<T, U, V>(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta);
+  cuApplyLayerNorm_<T, U, V>(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta, warp_size);
 }
 template<typename T, typename U, typename V> __device__
 void cuLoadWriteStridedInputs(
   const int i1_block,
@@ -388,7 +391,6 @@ void cuLoadWriteStridedInputs(
     }
   }
 }
 template<typename T, typename U, typename V> __device__
 void cuLoadAddStridedInputs(
   const int i1_block,
@@ -565,6 +567,7 @@ void cuComputeGradInput(
     const int numx = blockDim.x * blockDim.y;
     const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
     if (gamma != NULL) {
+#ifndef __HIP_PLATFORM_HCC__
       int l = 4*thrx;
       for (; l+3 < n2; l+=4*numx) {
         for (int k = 0; k < 4; ++k) {
@@ -580,7 +583,19 @@ void cuComputeGradInput(
         sum_loss1 += c_loss * gamma[l];
         sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
       }
+#else
+      // Optimization for ROCm MI100
+      for( int l = 0; l < n2 ; l += numx) {
+        int idx = l + thrx;
+        const U gamma_idx = static_cast<U>((idx<n2) ? gamma[idx] : V(0));
+        const U c_h = static_cast<U>((idx<n2) ? k_input[idx] : T(0));
+        const U c_loss = static_cast<U>((idx<n2) ? k_dout[idx] : V(0));
+        sum_loss1 += c_loss * gamma_idx;
+        sum_loss2 += c_loss * gamma_idx * (c_h - c_mean) * c_invvar;
+      }
+#endif
     } else {
+#ifndef __HIP_PLATFORM_HCC__
       int l = 4*thrx;
       for (; l+3 < n2; l+=4*numx) {
         for (int k = 0; k < 4; ++k) {
@@ -596,11 +611,20 @@ void cuComputeGradInput(
         sum_loss1 += c_loss;
         sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
       }
+#else
+      for( int l = 0; l < n2 ; l += numx) {
+        int idx = l + thrx;
+        const U c_h = static_cast<U>((idx<n2) ? k_input[idx] : T(0));
+        const U c_loss = static_cast<U>((idx<n2) ? k_dout[idx] : V(0));
+        sum_loss1 += c_loss;
+        sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
+      }
+#endif
     }
     // intra-warp reductions
-    for (int mask = blockDim.x/2; mask > 0; mask /= 2) {
-      sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask, 32);
-      sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask, 32);
+    for (int mask = blockDim.x / 2; mask > 0; mask /= 2) {
+      sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
+      sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
     }
     // inter-warp reductions
     if (blockDim.y > 1) {
@@ -676,7 +700,13 @@ void HostApplyLayerNorm(
     )
 {
     auto stream = at::cuda::getCurrentCUDAStream().stream();
-    const dim3 threads(32,4,1);
+    const int warp_size = at::cuda::getCurrentDeviceProperties()->warpSize;
+    dim3 threads(warp_size ,4, 1); // MI100 wavefront/warp = 64
+#ifdef __HIP_PLATFORM_HCC__
+    // Optimization for ROCm MI100
+    threads.y = 1;
+#endif
     const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
     const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
     int nshared =
@@ -684,7 +714,7 @@ void HostApplyLayerNorm(
         threads.y*sizeof(U)+(threads.y/2)*sizeof(U) :
         0;
     cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(
-      output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta);
+      output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta, warp_size);
 }
 void cuda_layer_norm(
@@ -736,12 +766,13 @@ void HostLayerNormGradient(
     )
 {
     auto stream = at::cuda::getCurrentCUDAStream().stream();
+    const int warp_size = at::cuda::getCurrentDeviceProperties()->warpSize;
     if (gamma != NULL && beta != NULL) {
       // compute grad_gamma(j) and grad_beta(j)
-      const int part_size = 16;
-      const dim3 threads2(32,4,1);
-      const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
+      const int part_size = warp_size;
+      const dim3 threads2(warp_size, 4, 1);
+      const dim3 blocks2((n2+threads2.x-1) / threads2.x,part_size, 1);
       const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
       const int nshared2_b = threads2.x * threads2.y * sizeof(U);
       const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
@@ -763,7 +794,7 @@ void HostLayerNormGradient(
           part_grad_gamma.DATA_PTR<U>(),
           part_grad_beta.DATA_PTR<U>());
-      const dim3 threads3(32,8,1);
+      const dim3 threads3(warp_size, 8, 1);
       const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1);
       const int nshared3 = threads3.x * threads3.y * sizeof(U);
       cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
@@ -776,9 +807,16 @@ void HostLayerNormGradient(
     }
     // compute grad_input
+    // https://github.com/microsoft/onnxruntime/pull/7682/files#diff-f9eace25e62b646410b067f96cd930c7fe843326dca1e8d383631ca27f1a8d00R540
+    // https://github.com/amathews-amd/onnxruntime/blob/80c0555c2bc17fb109190e2082cd3fda0a37984c/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu#L541
     const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
     const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
-    const dim3 threads1(32,4,1);
+    dim3 threads1(warp_size,4,1); // MI100 wavefront/warp = 64
+#ifdef __HIP_PLATFORM_HCC__
+    // Optimization for ROCm MI100
+    threads1.y = 2;
+#endif
     int nshared =
         threads1.y > 1 ?
             threads1.y*threads1.x*sizeof(U) :
@@ -834,3 +872,4 @@ void cuda_layer_norm_gradient(
       gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);
     )
 }
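
Background note (not part of the patch): the core of this change is replacing fixed 32-lane shuffle reductions with ones parameterized by the device's warp/wavefront size, queried on the host via at::cuda::getCurrentDeviceProperties()->warpSize (32 on NVIDIA GPUs, 64 on MI100) and threaded into the kernels. Below is a minimal, self-contained sketch of that pattern. The WARP_SHFL_DOWN wrapper, the warp_sum kernel, and the CUDA-flavored host code are illustrative assumptions for demonstration only; apex's real shuffle wrappers live in type_shim.h, and under ROCm the hip* runtime equivalents of the host calls apply.

// Warp-size-agnostic sum reduction sketch (illustrative, not apex code).
#include <cstdio>
#include <cuda_runtime.h>

#ifdef __HIP_PLATFORM_HCC__
// 64-lane wavefronts on MI100; width defaults to the wavefront size.
#define WARP_SHFL_DOWN(var, delta) __shfl_down(var, delta)
#else
#define WARP_SHFL_DOWN(var, delta) __shfl_down_sync(0xffffffff, var, delta)
#endif

// Each warp/wavefront folds its partial sum down to lane 0, and lane 0
// accumulates into the global result. Passing warp_size from the host keeps
// the stride loop correct on both 32-lane and 64-lane hardware, which is what
// the GPU_WARP_SIZE parameter threaded through cuWelfordMuSigma2 achieves.
__global__ void warp_sum(const float* in, float* out, int n, int warp_size)
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    float v = (tid < n) ? in[tid] : 0.0f;

    for (int stride = warp_size / 2; stride > 0; stride /= 2)
        v += WARP_SHFL_DOWN(v, stride);

    if (threadIdx.x % warp_size == 0)
        atomicAdd(out, v);                  // one partial per warp/wavefront
}

int main()
{
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    const int warp_size = prop.warpSize;    // 32 on NVIDIA, 64 on MI100

    const int n = 1 << 20;
    float *in, *out;
    cudaMallocManaged(&in, n * sizeof(float));
    cudaMallocManaged(&out, sizeof(float));
    for (int i = 0; i < n; ++i) in[i] = 1.0f;
    *out = 0.0f;

    const int block = 4 * warp_size;        // mirrors dim3 threads(warp_size, 4, 1)
    warp_sum<<<(n + block - 1) / block, block>>>(in, out, n, warp_size);
    cudaDeviceSynchronize();

    printf("sum = %.0f (expected %d)\n", *out, n);
    cudaFree(in); cudaFree(out);
    return 0;
}

On a 32-lane device the stride loop runs five iterations (16, 8, 4, 2, 1); on a 64-lane MI100 it runs six (32 down to 1). The hard-coded "for (int l = 0; l <= 4; ++l)" loop removed by this commit could only cover the 32-lane case, which is why the warp size is now passed in explicitly.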