Commit 28c5638d authored by hubertlu-tw
Browse files

Optimize HostRMSNormGradient and HostApplyRMSNorm for AMD GPUs

parent d755f1f1
......@@ -908,9 +908,13 @@ void HostApplyRMSNorm(
{
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int warp_size = at::cuda::getCurrentDeviceProperties()->warpSize;
const dim3 threads(32,4,1);
const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
dim3 threads(warp_size,4,1);
#ifdef __HIP_PLATFORM_HCC__
// Optimization for ROCm MI100
threads.y = 2;
#endif
int nshared =
threads.y > 1 ?
threads.y*sizeof(U)+(threads.y/2)*sizeof(U) :
......@@ -1080,10 +1084,10 @@ void HostRMSNormGradient(
V* grad_gamma)
{
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int warp_size = at::cuda::getCurrentDeviceProperties()->warpSize;
if (gamma != NULL) {
const int part_size = 16;
const dim3 threads2(32,4,1);
const int part_size = warp_size;
const dim3 threads2(warp_size,4,1);
const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
const int nshared2_b = threads2.x * threads2.y * sizeof(U);
......@@ -1106,7 +1110,7 @@ void HostRMSNormGradient(
part_grad_gamma.DATA_PTR<U>(), /* unused */
true);
const dim3 threads3(32,8,1);
const dim3 threads3(warp_size,8,1);
const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1);
const int nshared3 = threads3.x * threads3.y * sizeof(U);
cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
......@@ -1122,7 +1126,7 @@ void HostRMSNormGradient(
// compute grad_input
const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
const dim3 threads1(32,4,1);
const dim3 threads1(warp_size,4,1);
int nshared =
threads1.y > 1 ?
threads1.y*threads1.x*sizeof(U) :
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment