Commit d755f1f1 authored by hubertlu-tw's avatar hubertlu-tw
Browse files

Fix some bugs

parent 47921708
...@@ -75,7 +75,7 @@ void cuWelfordMuSigma2( ...@@ -75,7 +75,7 @@ void cuWelfordMuSigma2(
U& mu, U& mu,
U& sigma2, U& sigma2,
U* buf, U* buf,
const int GPU_WARP_SIZE) const int GPU_WARP_SIZE,
bool rms_only) bool rms_only)
{ {
// Assumptions: // Assumptions:
...@@ -185,7 +185,7 @@ void cuWelfordMuSigma2( ...@@ -185,7 +185,7 @@ void cuWelfordMuSigma2(
float& mu, float& mu,
float& sigma2, float& sigma2,
float* buf, float* buf,
const int GPU_WARP_SIZE) const int GPU_WARP_SIZE,
bool rms_only) bool rms_only)
{ {
// Assumptions: // Assumptions:
...@@ -369,9 +369,8 @@ void cuApplyLayerNorm_( ...@@ -369,9 +369,8 @@ void cuApplyLayerNorm_(
const U epsilon, const U epsilon,
const V* __restrict__ gamma, const V* __restrict__ gamma,
const V* __restrict__ beta, const V* __restrict__ beta,
const int GPU_WARP_SIZE const int GPU_WARP_SIZE,
bool rms_only bool rms_only)
)
{ {
// Assumptions: // Assumptions:
// 1) blockDim.x == warpSize // 1) blockDim.x == warpSize
...@@ -433,6 +432,20 @@ void cuApplyLayerNorm( ...@@ -433,6 +432,20 @@ void cuApplyLayerNorm(
cuApplyLayerNorm_<T, U, V>(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta, warp_size, false); cuApplyLayerNorm_<T, U, V>(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta, warp_size, false);
} }
template<typename T, typename U, typename V=T> __global__
void cuApplyRMSNorm(
V* __restrict__ output_vals,
U* __restrict__ invvar,
const T* __restrict__ vals,
const int n1,
const int n2,
const U epsilon,
const V* __restrict__ gamma,
const int warp_size)
{
cuApplyLayerNorm_<T, U, V>(output_vals, NULL, invvar, vals, n1, n2, epsilon, gamma, NULL, warp_size, true);
}
template<typename T, typename U, typename V> __device__ template<typename T, typename U, typename V> __device__
void cuLoadWriteStridedInputs( void cuLoadWriteStridedInputs(
const int i1_block, const int i1_block,
...@@ -882,6 +895,7 @@ void HostApplyLayerNorm( ...@@ -882,6 +895,7 @@ void HostApplyLayerNorm(
output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta, warp_size); output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta, warp_size);
} }
// TODO: Optimize HostRMSNormGradient for AMD GPUs: https://github.com/ROCmSoftwarePlatform/apex/pull/66/files
template<typename T, typename U, typename V=T> template<typename T, typename U, typename V=T>
void HostApplyRMSNorm( void HostApplyRMSNorm(
V* output, V* output,
...@@ -893,6 +907,7 @@ void HostApplyRMSNorm( ...@@ -893,6 +907,7 @@ void HostApplyRMSNorm(
const V* gamma) const V* gamma)
{ {
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
const int warp_size = at::cuda::getCurrentDeviceProperties()->warpSize;
const dim3 threads(32,4,1); const dim3 threads(32,4,1);
const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
...@@ -901,7 +916,7 @@ void HostApplyRMSNorm( ...@@ -901,7 +916,7 @@ void HostApplyRMSNorm(
threads.y*sizeof(U)+(threads.y/2)*sizeof(U) : threads.y*sizeof(U)+(threads.y/2)*sizeof(U) :
0; 0;
cuApplyRMSNorm<<<blocks, threads, nshared, stream>>>( cuApplyRMSNorm<<<blocks, threads, nshared, stream>>>(
output, invvar, input, n1, n2, U(epsilon), gamma); output, invvar, input, n1, n2, U(epsilon), gamma, warp_size);
} }
void cuda_layer_norm( void cuda_layer_norm(
...@@ -1200,3 +1215,4 @@ void cuda_rms_norm_gradient( ...@@ -1200,3 +1215,4 @@ void cuda_rms_norm_gradient(
gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL); gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL);
) )
} }
import unittest import unittest
import os import os
import random import random
import itertools
import torch import torch
import apex import apex
from torch.autograd import Variable from torch.autograd import Variable
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment