"docs/vscode:/vscode.git/clone" did not exist on "32b85dfa8d4a5fa54469ddc72be89d827c1ee9d6"
Unverified Commit 2d0f9cf2 authored by Chaitanya Sri Krishna Lolla's avatar Chaitanya Sri Krishna Lolla Committed by GitHub
Browse files

Enable fusedlayernorm extension (#3)

parent 3ccdd63d
......@@ -172,8 +172,8 @@ void cuWelfordMuSigma2(
for (; l+7 < n2; l+=8*numx) {
for (int k = 0; k < 8; k+=2) {
float2 curr = __half22float2(*((__half2*)(lvals+l+k)));
cuWelfordOnlineSum(curr.x,mu,sigma2,count);
cuWelfordOnlineSum(curr.y,mu,sigma2,count);
cuWelfordOnlineSum<float>(curr.x,mu,sigma2,count);
cuWelfordOnlineSum<float>(curr.y,mu,sigma2,count);
}
}
for (; l < n2; ++l) {
......@@ -230,9 +230,15 @@ void cuWelfordMuSigma2(
template<typename U> U rsqrt(U v) {
return U(1) / sqrt(v);
}
#if defined __HIP_PLATFORM_HCC__
__device__ float rsqrt(float v) {
return rsqrtf(v);
}
#else
template<> float rsqrt(float v) {
return rsqrtf(v);
}
#endif
template<> double rsqrt(double v) {
return rsqrt(v);
}
......@@ -293,7 +299,7 @@ void cuApplyLayerNorm(
// 1) blockDim.x == warpSize
// 2) Tensors are contiguous
//
for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
SharedMemory<U> shared;
U* buf = shared.getPointer();
U mu,sigma2;
......@@ -531,7 +537,7 @@ void cuComputeGradInput(
const T* gamma,
T* grad_input)
{
for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
U sum_loss1 = U(0);
U sum_loss2 = U(0);
const U c_mean = mean[i1];
......
......@@ -177,7 +177,13 @@ if "--cuda_ext" in sys.argv:
'-O3',
'--use_fast_math'] + version_dependent_macros}))
else:
print ("INFO: Skipping FusedLayerNorm extension.")
print ("INFO: Building FusedLayerNorm extension.")
ext_modules.append(
CUDAExtension(name='fused_layer_norm_cuda',
sources=['csrc/layer_norm_cuda.cpp',
'csrc/hip/layer_norm_hip_kernel.hip'],
extra_compile_args={'cxx' : ['-O3'] + version_dependent_macros,
'nvcc' : []}))
if not is_rocm_pytorch:
ext_modules.append(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment