Unverified commit 1cb3da87, authored by Hubert Lu, committed by GitHub

Optimize layer normalization for AMD GPUs (#66)

* Optimize fused layer normalization for MI100

* Optimize cuComputePartGradGammaBeta for AMD GPUs
parent 151d150b
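The central change in this commit: fixed 32-lane shuffle reductions are replaced with loops driven by the device's actual warp size, so the same kernels run on CUDA (warpSize 32) and on ROCm, where an MI100 wavefront is 64 lanes wide. A minimal sketch of the pattern, assuming WARP_SHFL_DOWN is the project's portability macro (wrapping __shfl_down_sync on CUDA and __shfl_down on HIP):

template <typename U> __device__
U warpReduceSum(U val, const int warp_size) {
  // Tree reduction: halve the stride each step; after log2(warp_size)
  // steps, lane 0 holds the sum over the whole warp/wavefront.
  for (int stride = warp_size / 2; stride > 0; stride /= 2) {
    val += WARP_SHFL_DOWN(val, stride);
  }
  return val;
}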
@@ -8,6 +8,7 @@
#include "type_shim.h"
template<typename U> __device__
void cuWelfordOnlineSum(
const U curr,
@@ -56,7 +57,8 @@ void cuWelfordMuSigma2(
const int i1,
U& mu,
U& sigma2,
U* buf)
U* buf,
const int GPU_WARP_SIZE)
{
// Assumptions:
// 1) blockDim.x == warpSize
@@ -86,12 +88,12 @@ void cuWelfordMuSigma2(
cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
}
// intra-warp reductions
for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x+(1<<l))&31;
U muB = WARP_SHFL(mu, srcLaneB, 32);
U countB = WARP_SHFL(count, srcLaneB, 32);
U sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
#pragma unroll
for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) {
U muB = WARP_SHFL_DOWN(mu, stride);
U countB = WARP_SHFL_DOWN(count, stride);
U sigma2B = WARP_SHFL_DOWN(sigma2, stride);
cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
}
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
@@ -126,8 +128,8 @@ void cuWelfordMuSigma2(
sigma2 = ubuf[1]/U(n2);
// don't care about final value of count, we know count == n2
} else {
mu = WARP_SHFL(mu, 0, 32);
sigma2 = WARP_SHFL(sigma2/U(n2), 0, 32);
mu = WARP_SHFL(mu, 0);
sigma2 = WARP_SHFL(sigma2 / U(n2), 0);
}
}
}
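For reference, cuChanOnlineSum merges two partial Welford states (mean, M2, count) during the reductions above; a sketch of the standard Chan et al. pairwise combination it implements, with illustrative names:

template <typename U> __device__
void chanMergeSketch(const U muB, const U sigma2B, const U countB,
                     U& mu, U& sigma2, U& count) {
  // Merge partner state B into this lane's state A:
  //   n  = nA + nB
  //   mu = muA + delta * nB / n,           delta = muB - muA
  //   M2 = M2A + M2B + delta^2 * nA * nB / n
  U delta = muB - mu;
  U nA = count;
  count = count + countB;
  if (count > U(0)) {
    U nB_over_n = countB / count;
    mu = mu + delta * nB_over_n;
    sigma2 = sigma2 + sigma2B + delta * delta * nA * nB_over_n;
  }
}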
@@ -140,7 +142,8 @@ void cuWelfordMuSigma2(
const int i1,
float& mu,
float& sigma2,
float* buf)
float* buf,
const int GPU_WARP_SIZE)
{
// Assumptions:
// 1) blockDim.x == warpSize
@@ -181,12 +184,12 @@ void cuWelfordMuSigma2(
cuWelfordOnlineSum(curr,mu,sigma2,count);
}
// intra-warp reductions
for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x+(1<<l))&31;
float muB = WARP_SHFL(mu, srcLaneB, 32);
float countB = WARP_SHFL(count, srcLaneB, 32);
float sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
#pragma unroll
for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) {
float muB = WARP_SHFL_DOWN(mu, stride);
float countB = WARP_SHFL_DOWN(count, stride);
float sigma2B = WARP_SHFL_DOWN(sigma2, stride);
cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
}
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
@@ -221,8 +224,8 @@ void cuWelfordMuSigma2(
sigma2 = ubuf[1]/float(n2);
// don't care about final value of count, we know count == n2
} else {
mu = WARP_SHFL(mu, 0, 32);
sigma2 = WARP_SHFL(sigma2/float(n2), 0, 32);
mu = WARP_SHFL(mu, 0);
sigma2 = WARP_SHFL(sigma2 / float(n2), 0);
}
}
}
@@ -292,7 +295,8 @@ void cuApplyLayerNorm_(
const int n2,
const U epsilon,
const V* __restrict__ gamma,
const V* __restrict__ beta
const V* __restrict__ beta,
const int GPU_WARP_SIZE
)
{
// Assumptions:
@@ -303,7 +307,7 @@ void cuApplyLayerNorm_(
SharedMemory<U> shared;
U* buf = shared.getPointer();
U mu,sigma2;
cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf);
cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf, GPU_WARP_SIZE);
const T* lvals = vals + i1*n2;
V* ovals = output_vals + i1*n2;
U c_invvar = rsqrt(sigma2 + epsilon);
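The elided remainder of cuApplyLayerNorm_ then normalizes each element of the row; a sketch of that loop (illustrative, not the verbatim body), assuming thrx and numx are the flattened thread index and per-block thread count as defined elsewhere in the file:

// y = gamma * (x - mu) * rsqrt(sigma2 + epsilon) + beta
for (int i = thrx; i < n2; i += numx) {
  U c_h = static_cast<U>(lvals[i]);
  V norm = static_cast<V>(c_invvar * (c_h - mu));
  ovals[i] = (gamma != NULL && beta != NULL) ? gamma[i] * norm + beta[i] : norm;
}
if (threadIdx.x == 0 && threadIdx.y == 0) {
  mean[i1] = mu;          // saved for the backward pass
  invvar[i1] = c_invvar;
}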
@@ -338,13 +342,12 @@ void cuApplyLayerNorm(
const int n2,
const U epsilon,
const V* __restrict__ gamma,
const V* __restrict__ beta
)
const V* __restrict__ beta,
const int warp_size)
{
cuApplyLayerNorm_<T, U, V>(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta);
cuApplyLayerNorm_<T, U, V>(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta, warp_size);
}
template<typename T, typename U, typename V> __device__
void cuLoadWriteStridedInputs(
const int i1_block,
@@ -388,7 +391,6 @@ void cuLoadWriteStridedInputs(
}
}
}
template<typename T, typename U, typename V> __device__
void cuLoadAddStridedInputs(
const int i1_block,
@@ -565,6 +567,7 @@ void cuComputeGradInput(
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
if (gamma != NULL) {
#ifndef __HIP_PLATFORM_HCC__
int l = 4*thrx;
for (; l+3 < n2; l+=4*numx) {
for (int k = 0; k < 4; ++k) {
@@ -580,7 +583,19 @@
sum_loss1 += c_loss * gamma[l];
sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
}
#else
// Optimization for ROCm MI100
for (int l = 0; l < n2; l += numx) {
int idx = l + thrx;
const U gamma_idx = static_cast<U>((idx<n2) ? gamma[idx] : V(0));
const U c_h = static_cast<U>((idx<n2) ? k_input[idx] : T(0));
const U c_loss = static_cast<U>((idx<n2) ? k_dout[idx] : V(0));
sum_loss1 += c_loss * gamma_idx;
sum_loss2 += c_loss * gamma_idx * (c_h - c_mean) * c_invvar;
}
#endif
} else {
#ifndef __HIP_PLATFORM_HCC__
int l = 4*thrx;
for (; l+3 < n2; l+=4*numx) {
for (int k = 0; k < 4; ++k) {
@@ -596,11 +611,20 @@
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
}
#else
for (int l = 0; l < n2; l += numx) {
int idx = l + thrx;
const U c_h = static_cast<U>((idx<n2) ? k_input[idx] : T(0));
const U c_loss = static_cast<U>((idx<n2) ? k_dout[idx] : V(0));
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
}
#endif
}
// intra-warp reductions
for (int mask = blockDim.x/2; mask > 0; mask /= 2) {
sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask, 32);
sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask, 32);
for (int mask = blockDim.x / 2; mask > 0; mask /= 2) {
sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
}
// inter-warp reductions
if (blockDim.y > 1) {
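The ROCm branches above replace the 4x-unrolled main loop plus scalar remainder with a single predicated, strided loop: on each iteration the 64 lanes of a wavefront read 64 consecutive elements, so every load is coalesced and no remainder loop is needed. A compact sketch of the pattern (illustrative helper, not part of the patch):

template <typename T, typename U, typename V> __device__
void stridedLossSums(const T* k_input, const V* k_dout, const int n2,
                     const int thrx, const int numx,
                     const U c_mean, const U c_invvar,
                     U& sum_loss1, U& sum_loss2) {
  for (int l = 0; l < n2; l += numx) {
    const int idx = l + thrx;  // consecutive lanes hit consecutive elements
    // Out-of-range lanes contribute zero instead of branching away.
    const U c_h = static_cast<U>((idx < n2) ? k_input[idx] : T(0));
    const U c_loss = static_cast<U>((idx < n2) ? k_dout[idx] : V(0));
    sum_loss1 += c_loss;
    sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
  }
}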
@@ -676,7 +700,13 @@ void HostApplyLayerNorm(
)
{
auto stream = at::cuda::getCurrentCUDAStream().stream();
const dim3 threads(32,4,1);
const int warp_size = at::cuda::getCurrentDeviceProperties()->warpSize;
dim3 threads(warp_size, 4, 1); // warpSize is 32 on CUDA, 64 on ROCm MI100 (wavefront)
#ifdef __HIP_PLATFORM_HCC__
// Optimization for ROCm MI100
threads.y = 1;
#endif
const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
int nshared =
@@ -684,7 +714,7 @@
threads.y*sizeof(U)+(threads.y/2)*sizeof(U) :
0;
cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(
output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta);
output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta, warp_size);
}
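A consequence of threads.y = 1 on ROCm is visible in the nshared expression above: with a single 64-lane wavefront per row, the blockDim.y > 1 inter-warp pass is never taken and the kernel needs no dynamic shared memory. A hypothetical helper making the sizing rule explicit:

// Illustrative only: dynamic shared memory needed by the forward kernel
// as a function of the block shape (mirrors the nshared expression above).
inline int lnForwardSharedBytes(const dim3 threads, const size_t elem_size) {
  // One warp/wavefront per row (threads.y == 1) reduces entirely in
  // registers via shuffles, so no staging buffer is required.
  return threads.y > 1
      ? (int)((threads.y + threads.y / 2) * elem_size)
      : 0;
}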
void cuda_layer_norm(
@@ -736,12 +766,13 @@ void HostLayerNormGradient(
)
{
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int warp_size = at::cuda::getCurrentDeviceProperties()->warpSize;
if (gamma != NULL && beta != NULL) {
// compute grad_gamma(j) and grad_beta(j)
const int part_size = 16;
const dim3 threads2(32,4,1);
const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
const int part_size = warp_size;
const dim3 threads2(warp_size, 4, 1);
const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1);
const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
const int nshared2_b = threads2.x * threads2.y * sizeof(U);
const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
@@ -763,7 +794,7 @@
part_grad_gamma.DATA_PTR<U>(),
part_grad_beta.DATA_PTR<U>());
const dim3 threads3(32,8,1);
const dim3 threads3(warp_size, 8, 1);
const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1);
const int nshared3 = threads3.x * threads3.y * sizeof(U);
cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
@@ -776,9 +807,16 @@
}
// compute grad_input
// https://github.com/microsoft/onnxruntime/pull/7682/files#diff-f9eace25e62b646410b067f96cd930c7fe843326dca1e8d383631ca27f1a8d00R540
// https://github.com/amathews-amd/onnxruntime/blob/80c0555c2bc17fb109190e2082cd3fda0a37984c/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu#L541
const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
const dim3 threads1(32,4,1);
dim3 threads1(warp_size, 4, 1); // warpSize is 32 on CUDA, 64 on ROCm MI100 (wavefront)
#ifdef __HIP_PLATFORM_HCC__
// Optimization for ROCm MI100
threads1.y = 2;
#endif
int nshared =
threads1.y > 1 ?
threads1.y*threads1.x*sizeof(U) :
@@ -834,3 +872,4 @@ void cuda_layer_norm_gradient(
gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);
)
}
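For orientation, the two-pass cuComputePartGradGammaBeta / cuComputeGradGammaBeta pipeline tuned in this commit computes the standard layer-norm weight gradients. A naive single-pass reference, hypothetical and unoptimized, that produces the same values:

template <typename T, typename U, typename V>
__global__ void gradGammaBetaReference(
    const V* __restrict__ dout, const T* __restrict__ input,
    const U* __restrict__ mean, const U* __restrict__ invvar,
    const int n1, const int n2,
    V* __restrict__ grad_gamma, V* __restrict__ grad_beta) {
  const int j = blockIdx.x * blockDim.x + threadIdx.x;
  if (j >= n2) return;
  U sum_gamma = U(0);
  U sum_beta = U(0);
  // grad_gamma[j] = sum_i dout[i][j] * (input[i][j] - mean[i]) * invvar[i]
  // grad_beta[j]  = sum_i dout[i][j]
  for (int i = 0; i < n1; ++i) {
    const U c_loss = static_cast<U>(dout[i * n2 + j]);
    const U c_h = static_cast<U>(input[i * n2 + j]);
    sum_gamma += c_loss * (c_h - mean[i]) * invvar[i];
    sum_beta += c_loss;
  }
  grad_gamma[j] = static_cast<V>(sum_gamma);
  grad_beta[j] = static_cast<V>(sum_beta);
}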