OpenDAS / Megatron-LM · Commits
Commit d650e6a2, authored Apr 19, 2023 by Jon Barker, committed by Jared Casper, Apr 19, 2023

replace custom layer_norm_cuda with Apex layer_norm_cuda

parent 8dbd0757
Showing 7 changed files with 10 additions and 1273 deletions (+10 −1273)
megatron/fused_kernels/__init__.py                      +1  −24
megatron/fused_kernels/fused_weight_gradient_dense.cpp  +0  −47
megatron/fused_kernels/fused_weight_gradient_dense.cu   +0  −157
megatron/fused_kernels/layer_norm_cuda.cpp              +0  −187
megatron/fused_kernels/layer_norm_cuda_kernel.cu        +0  −818
megatron/fused_kernels/tests/test_fused_kernels.py      +3  −4
megatron/model/fused_layer_norm.py                      +6  −36
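Taken together, the commit deletes Megatron's in-tree fork of the Apex layer norm extension (and the custom fused weight-gradient GEMM) and imports the equivalent autograd function from an installed Apex instead. A minimal sketch of the import-level effect, assuming Apex is installed with its CUDA extensions built:

    # Before this commit: a JIT-built in-tree fork of the Apex kernel.
    # import fused_mix_prec_layer_norm_cuda   # compiled from megatron/fused_kernels

    # After this commit: the autograd function comes directly from Apex.
    from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction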
megatron/fused_kernels/__init__.py  View file @ d650e6a2

@@ -54,7 +54,7 @@ def load(args):
                         '-U__CUDA_NO_HALF_CONVERSIONS__',
                         '--expt-relaxed-constexpr',
                         '--expt-extended-lambda']

     # Upper triangular softmax.
     sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp',
              srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu']
@@ -74,29 +74,6 @@ def load(args):
     scaled_softmax_cuda = _cpp_extention_load_helper(
         "scaled_softmax_cuda", sources, extra_cuda_flags)

-    # =================================
-    # Mixed precision fused layer norm.
-    # =================================
-
-    extra_hopper_flags = ['-U__CUDA_NO_HALF_OPERATORS__',
-                          '-U__CUDA_NO_HALF_CONVERSIONS__']
-    extra_cuda_flags = ['-maxrregcount=50']
-    sources=[srcpath / 'layer_norm_cuda.cpp',
-             srcpath / 'layer_norm_cuda_kernel.cu']
-    fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
-        "fused_mix_prec_layer_norm_cuda", sources,
-        extra_cuda_flags + extra_hopper_flags)
-
-    # =================================
-    # Fused gradient accumulation to weight gradient computation of linear layer
-    # =================================
-    if args.gradient_accumulation_fusion:
-        sources=[srcpath / 'fused_weight_gradient_dense.cpp',
-                 srcpath / 'fused_weight_gradient_dense.cu']
-        fused_dense_cuda = _cpp_extention_load_helper(
-            "fused_dense_cuda", sources, extra_hopper_flags)

 def _get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
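For context, `_cpp_extention_load_helper` wraps PyTorch's JIT extension loader; a rough sketch of what each call above does (only the names and flags come from the diff, the wrapper body is an assumption):

    from torch.utils import cpp_extension

    def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
        # Assumed to forward to torch.utils.cpp_extension.load, which
        # compiles the listed sources into an importable Python module.
        return cpp_extension.load(
            name=name,
            sources=[str(s) for s in sources],
            extra_cuda_cflags=extra_cuda_flags,
        )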
megatron/fused_kernels/fused_weight_gradient_dense.cpp  (deleted, 100644 → 0)  View file @ 8dbd0757

#include <torch/torch.h>
#include <torch/extension.h>
#include <vector>
#include <stdio.h>

#include "type_shim.h"

template <typename T>
int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight,
                               int in_dim, int hidden_dim, int out_dim);

void wgrad_gemm_accum_fp32(const at::Tensor input, const at::Tensor d_output,
                           at::Tensor d_weight) {
    at::Tensor input_2d, d_output_2d;
    // input tensor: collapse to the first dim
    auto in_sizes = input.sizes();
    if (input.dim() > 2) {
        input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]});
    } else {
        input_2d = input;
    }
    // d_output tensor: collapse to the first dim
    auto d_out_sizes = d_output.sizes();
    if (d_output.dim() > 2) {
        d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]});
    } else {
        d_output_2d = d_output;
    }

    int hidden_dim = input_2d.size(0);
    int in_dim = input_2d.size(1);
    int out_dim = d_weight.size(0);

    DISPATCH_HALF_BFLOAT_AND_FLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp32",
        int result = wgrad_gemm_accum_fp32_cuda<scalar_t>(
            input_2d.data_ptr<scalar_t>(),
            d_output_2d.data_ptr<scalar_t>(),
            d_weight.data_ptr<float>(),
            in_dim,
            hidden_dim,
            out_dim);
    );
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32,
          "wgrad gemm accum in fp32");
}
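Before this commit, the binding above was reachable from Python via the JIT-built module. A hypothetical call, using the module name `fused_dense_cuda` registered in `__init__.py` (shapes follow the size checks in the C++ code):

    import torch
    import fused_dense_cuda  # pre-commit JIT-built module

    hidden_dim, in_dim, out_dim = 1024, 512, 256
    inp = torch.randn(hidden_dim, in_dim, device='cuda', dtype=torch.half)
    d_out = torch.randn(hidden_dim, out_dim, device='cuda', dtype=torch.half)
    d_weight = torch.zeros(out_dim, in_dim, device='cuda', dtype=torch.float)

    # Accumulates the fp32 weight gradient into d_weight in place.
    fused_dense_cuda.wgrad_gemm_accum_fp32(inp, d_out, d_weight)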
megatron/fused_kernels/fused_weight_gradient_dense.cu  (deleted, 100644 → 0)  View file @ 8dbd0757

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <torch/torch.h>

/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>

// BF16 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m, int n, int k,
    const float *alpha,
    at::BFloat16 *A, int lda,
    at::BFloat16 *B, int ldb,
    const float *beta,
    float *C, int ldc) {
  return cublasGemmEx(
      handle, transa, transb, m, n, k,
      alpha,
      A, CUDA_R_16BF, lda,
      B, CUDA_R_16BF, ldb,
      beta,
      C, CUDA_R_32F, ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

// FP16 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m, int n, int k,
    const float *alpha,
    at::Half *A, int lda,
    at::Half *B, int ldb,
    const float *beta,
    float *C, int ldc) {
  return cublasGemmEx(
      handle, transa, transb, m, n, k,
      alpha,
      A, CUDA_R_16F, lda,
      B, CUDA_R_16F, ldb,
      beta,
      C, CUDA_R_32F, ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

// FP32 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m, int n, int k,
    const float *alpha,
    float *A, int lda,
    float *B, int ldb,
    const float *beta,
    float *C, int ldc) {
  return cublasGemmEx(
      handle, transa, transb, m, n, k,
      alpha,
      A, CUDA_R_32F, lda,
      B, CUDA_R_32F, ldb,
      beta,
      C, CUDA_R_32F, ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

template <typename T>
int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight,
                               int in_dim, int hidden_dim, int out_dim) {
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream;
  cublasGetStream(handle, &stream);
  const float alpha = 1.0;
  const float beta = 1.0;
  int status = 1;

  status = gemmex_wrapper(
      handle,
      CUBLAS_OP_N,
      CUBLAS_OP_T,
      in_dim,
      out_dim,
      hidden_dim,
      &alpha,
      input,
      in_dim,
      d_output,
      out_dim,
      &beta,
      d_weight,
      in_dim);
  return status;
}

template int wgrad_gemm_accum_fp32_cuda<at::Half>(
    at::Half *input, at::Half *d_output, float *d_weight,
    int in_dim, int hidden_dim, int out_dim);
template int wgrad_gemm_accum_fp32_cuda<at::BFloat16>(
    at::BFloat16 *input, at::BFloat16 *d_output, float *d_weight,
    int in_dim, int hidden_dim, int out_dim);
template int wgrad_gemm_accum_fp32_cuda<float>(
    float *input, float *d_output, float *d_weight,
    int in_dim, int hidden_dim, int out_dim);
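Translating the column-major cuBLAS call back to row-major tensors, the kernel computes d_weight += d_output^T · input with fp32 accumulation (alpha = beta = 1). A PyTorch reference one could use to sanity-check it, a sketch rather than part of the commit:

    import torch

    def wgrad_gemm_accum_fp32_ref(input, d_output, d_weight):
        # Collapse leading dims, then accumulate the fp32 weight gradient;
        # beta == 1.0 in the cuBLAS call, hence add_ rather than copy_.
        input_2d = input.reshape(-1, input.shape[-1])
        d_output_2d = d_output.reshape(-1, d_output.shape[-1])
        d_weight.add_(d_output_2d.float().t() @ input_2d.float())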
megatron/fused_kernels/layer_norm_cuda.cpp  (deleted, 100644 → 0)  View file @ 8dbd0757

/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */

/* This code is copied from NVIDIA apex:
 *   https://github.com/NVIDIA/apex
 * with minor changes. */

#include <torch/extension.h>
#include <vector>
#include <cassert>
#include "compat.h"

namespace {

void compute_n1_n2(
    at::Tensor input,
    at::IntArrayRef normalized_shape,
    int& n1,
    int& n2) {
  int idiff = input.ndimension() - normalized_shape.size();
  n2 = 1;
  for (int i = 0; i < (int)normalized_shape.size(); ++i) {
    assert(input.sizes()[i + idiff] == normalized_shape[i]);
    n2 *= normalized_shape[i];
  }
  n1 = 1;
  for (int i = 0; i < idiff; ++i) {
    n1 *= input.sizes()[i];
  }
}

void check_args(
    at::IntArrayRef normalized_shape,
    at::Tensor gamma,
    at::Tensor beta) {
  TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
  TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
}

void check_args(
    at::Tensor input,
    at::IntArrayRef normalized_shape,
    int& n1,
    int& n2) {
  int64_t normalized_ndim = normalized_shape.size();

  if (normalized_ndim < 1) {
    std::stringstream ss;
    ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
       << "containing at least one element, but got normalized_shape="
       << normalized_shape;
    throw std::runtime_error(ss.str());
  }

  auto input_shape = input.sizes();
  auto input_ndim = input.dim();

  if (input_ndim < normalized_ndim ||
      !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {
    std::stringstream ss;
    ss << "Given normalized_shape=" << normalized_shape
       << ", expected input with shape [*";
    for (auto size : normalized_shape) {
      ss << ", " << size;
    }
    ss << "], but got input of size" << input_shape;
    throw std::runtime_error(ss.str());
  }

  compute_n1_n2(input, normalized_shape, n1, n2);
}

void check_args(
    at::Tensor input,
    at::IntArrayRef normalized_shape,
    at::Tensor gamma,
    at::Tensor beta,
    int& n1,
    int& n2) {
  check_args(input, normalized_shape, n1, n2);
  check_args(normalized_shape, gamma, beta);
}

}  // namespace

void cuda_layer_norm(
    at::Tensor* output,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    at::IntArrayRef normalized_shape,
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon);

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<at::Tensor> layer_norm_affine(
    at::Tensor input,
    at::IntArrayRef normalized_shape,
    at::Tensor gamma,
    at::Tensor beta,
    double epsilon) {
  CHECK_INPUT(input);
  CHECK_INPUT(gamma);
  CHECK_INPUT(beta);
  int n1, n2;
  check_args(input, normalized_shape, gamma, beta, n1, n2);

  at::Tensor output = at::empty_like(
      input, gamma.options().dtype(gamma.scalar_type()));
  at::Tensor mean = at::empty(
      {n1}, input.options().dtype(at::ScalarType::Float));
  at::Tensor invvar = at::empty_like(mean);

  cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
                  normalized_shape, &gamma, &beta, epsilon);

  return {output, mean, invvar};
}

void cuda_layer_norm_gradient(
    at::Tensor* dout,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    at::IntArrayRef normalized_shape,
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon,
    at::Tensor* grad_input,
    at::Tensor* grad_gamma,
    at::Tensor* grad_beta);

std::vector<at::Tensor> layer_norm_gradient_affine(
    at::Tensor dout,
    at::Tensor mean,
    at::Tensor invvar,
    at::Tensor input,
    at::IntArrayRef normalized_shape,
    at::Tensor gamma,
    at::Tensor beta,
    double epsilon) {
  CHECK_INPUT(dout);
  CHECK_INPUT(mean);
  CHECK_INPUT(invvar);
  CHECK_INPUT(input);
  CHECK_INPUT(gamma);
  CHECK_INPUT(beta);
  int n1, n2;
  check_args(input, normalized_shape, gamma, beta, n1, n2);

  at::Tensor grad_input = at::empty_like(input);
  at::Tensor grad_gamma = at::empty_like(gamma);
  at::Tensor grad_beta = at::empty_like(beta);

  cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
                           normalized_shape, &gamma, &beta, epsilon,
                           &grad_input, &grad_gamma, &grad_beta);

  return {grad_input, grad_gamma, grad_beta};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
  m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)");
}
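`forward_affine` returns the per-row mean and inverse standard deviation alongside the output so the backward pass can reuse them. A sketch of how it lined up with PyTorch's reference layer norm before the deletion (the argument order comes from megatron/model/fused_layer_norm.py below; the tolerances are an assumption):

    import torch
    import fused_mix_prec_layer_norm_cuda  # pre-commit JIT-built module

    x = torch.randn(8, 1024, device='cuda')
    w = torch.ones(1024, device='cuda')
    b = torch.zeros(1024, device='cuda')

    out, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
        x, (1024,), w, b, 1e-5)
    ref = torch.nn.functional.layer_norm(x, (1024,), w, b, 1e-5)
    torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)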
megatron/fused_kernels/layer_norm_cuda_kernel.cu  (deleted, 100644 → 0)  View file @ 8dbd0757

/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */

/* This code is copied from NVIDIA apex:
 *   https://github.com/NVIDIA/apex
 * with minor changes. */

#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/DeviceUtils.cuh"

#include <cuda.h>
#include <cuda_runtime.h>

#include "type_shim.h"
template<typename U> __device__
void cuWelfordOnlineSum(
  const U curr,
  U& mu,
  U& sigma2,
  U& count)
{
  count = count + U(1);
  U delta = curr - mu;
  U lmean = mu + delta / count;
  mu = lmean;
  U delta2 = curr - lmean;
  sigma2 = sigma2 + delta * delta2;
}

template<typename U> __device__
void cuChanOnlineSum(
  const U muB,
  const U sigma2B,
  const U countB,
  U& mu,
  U& sigma2,
  U& count)
{
  U delta = muB - mu;
  U nA = count;
  U nB = countB;
  count = count + countB;
  U nX = count;
  if (nX > U(0)) {
    nA = nA / nX;
    nB = nB / nX;
    mu = nA * mu + nB * muB;
    sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX;
  } else {
    mu = U(0);
    sigma2 = U(0);
  }
}
template<typename T, typename U> __device__
void cuWelfordMuSigma2(
  const T* __restrict__ vals,
  const int n1,
  const int n2,
  const int i1,
  U& mu,
  U& sigma2,
  U* buf)
{
  // Assumptions:
  // 1) blockDim.x == warpSize
  // 2) Tensor is contiguous
  // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
  //
  // compute variance and mean over n2
  U count = U(0);
  mu = U(0);
  sigma2 = U(0);
  if (i1 < n1) {
    // one warp normalizes one n1 index,
    // synchronization is implicit
    // initialize with standard Welford algorithm
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    const T* lvals = vals + i1 * n2;
    int l = 4 * thrx;
    for (; l + 3 < n2; l += 4 * numx) {
      for (int k = 0; k < 4; ++k) {
        U curr = static_cast<U>(lvals[l + k]);
        cuWelfordOnlineSum<U>(curr, mu, sigma2, count);
      }
    }
    for (; l < n2; ++l) {
      U curr = static_cast<U>(lvals[l]);
      cuWelfordOnlineSum<U>(curr, mu, sigma2, count);
    }
    // intra-warp reductions
    for (int l = 0; l <= 4; ++l) {
      int srcLaneB = (threadIdx.x + (1 << l)) & 31;
      U muB = WARP_SHFL(mu, srcLaneB);
      U countB = WARP_SHFL(count, srcLaneB);
      U sigma2B = WARP_SHFL(sigma2, srcLaneB);
      cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
    }
    // threadIdx.x == 0 has correct values for each warp
    // inter-warp reductions
    if (blockDim.y > 1) {
      U* ubuf = (U*)buf;
      U* ibuf = (U*)(ubuf + blockDim.y);
      for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2 * offset) {
          const int wrt_y = threadIdx.y - offset;
          ubuf[2 * wrt_y] = mu;
          ubuf[2 * wrt_y + 1] = sigma2;
          ibuf[wrt_y] = count;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.x == 0 && threadIdx.y < offset) {
          U muB = ubuf[2 * threadIdx.y];
          U sigma2B = ubuf[2 * threadIdx.y + 1];
          U countB = ibuf[threadIdx.y];
          cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
        }
        __syncthreads();
      }
      // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
      if (threadIdx.x == 0 && threadIdx.y == 0) {
        ubuf[0] = mu;
        ubuf[1] = sigma2;
      }
      __syncthreads();
      mu = ubuf[0];
      sigma2 = ubuf[1] / U(n2);
      // don't care about final value of count, we know count == n2
    } else {
      mu = WARP_SHFL(mu, 0);
      sigma2 = WARP_SHFL(sigma2 / U(n2), 0);
    }
  }
}
template<> __device__
void cuWelfordMuSigma2(
  const at::Half* __restrict__ vals,
  const int n1,
  const int n2,
  const int i1,
  float& mu,
  float& sigma2,
  float* buf)
{
  // Assumptions:
  // 1) blockDim.x == warpSize
  // 2) Tensor is contiguous
  // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
  //
  // compute variance and mean over n2
  float count = 0.0f;
  mu = float(0);
  sigma2 = float(0);
  if (i1 < n1) {
    // one warp normalizes one n1 index,
    // synchronization is implicit
    // initialize with standard Welford algorithm
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    const at::Half* lvals = vals + i1 * n2;
    int l = 8 * thrx;
    if ((((size_t)lvals) & 3) != 0) {
      // 16 bit alignment
      // first thread consumes first point
      if (thrx == 0) {
        float curr = static_cast<float>(lvals[0]);
        cuWelfordOnlineSum(curr, mu, sigma2, count);
      }
      ++l;
    }
    // at this point, lvals[l] are 32 bit aligned for all threads.
    for (; l + 7 < n2; l += 8 * numx) {
      for (int k = 0; k < 8; k += 2) {
        float2 curr = __half22float2(*((__half2*)(lvals + l + k)));
        cuWelfordOnlineSum(curr.x, mu, sigma2, count);
        cuWelfordOnlineSum(curr.y, mu, sigma2, count);
      }
    }
    for (; l < n2; ++l) {
      float curr = static_cast<float>(lvals[l]);
      cuWelfordOnlineSum(curr, mu, sigma2, count);
    }
    // intra-warp reductions
    for (int l = 0; l <= 4; ++l) {
      int srcLaneB = (threadIdx.x + (1 << l)) & 31;
      float muB = WARP_SHFL(mu, srcLaneB);
      float countB = WARP_SHFL(count, srcLaneB);
      float sigma2B = WARP_SHFL(sigma2, srcLaneB);
      cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
    }
    // threadIdx.x == 0 has correct values for each warp
    // inter-warp reductions
    if (blockDim.y > 1) {
      float* ubuf = (float*)buf;
      float* ibuf = (float*)(ubuf + blockDim.y);
      for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2 * offset) {
          const int wrt_y = threadIdx.y - offset;
          ubuf[2 * wrt_y] = mu;
          ubuf[2 * wrt_y + 1] = sigma2;
          ibuf[wrt_y] = count;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.x == 0 && threadIdx.y < offset) {
          float muB = ubuf[2 * threadIdx.y];
          float sigma2B = ubuf[2 * threadIdx.y + 1];
          float countB = ibuf[threadIdx.y];
          cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
        }
        __syncthreads();
      }
      // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
      if (threadIdx.x == 0 && threadIdx.y == 0) {
        ubuf[0] = mu;
        ubuf[1] = sigma2;
      }
      __syncthreads();
      mu = ubuf[0];
      sigma2 = ubuf[1] / float(n2);
      // don't care about final value of count, we know count == n2
    } else {
      mu = WARP_SHFL(mu, 0);
      sigma2 = WARP_SHFL(sigma2 / float(n2), 0);
    }
  }
}
template<typename U> U rsqrt(U v) {
  return U(1) / sqrt(v);
}
template<> float rsqrt(float v) {
  return rsqrtf(v);
}
template<> double rsqrt(double v) {
  return rsqrt(v);
}

namespace {
// This is the un-specialized struct. Note that we prevent instantiation of this
// struct by putting an undefined symbol in the function body so it won't compile.
//  template <typename T>
//  struct SharedMemory
//  {
//      // Ensure that we won't compile any un-specialized types
//      __device__ T *getPointer()
//      {
//          extern __device__ void error(void);
//          error();
//          return NULL;
//      }
//  };
// https://github.com/NVIDIA/apex/issues/246
template <typename T>
struct SharedMemory;

template <>
struct SharedMemory<float>
{
    __device__ float *getPointer()
    {
        extern __shared__ float s_float[];
        return s_float;
    }
};

}  // namespace
template<typename T, typename U, typename V> __global__
void cuApplyLayerNorm(
  V* __restrict__ output_vals,
  U* __restrict__ mean,
  U* __restrict__ invvar,
  const T* __restrict__ vals,
  const int n1,
  const int n2,
  const U epsilon,
  const V* __restrict__ gamma,
  const V* __restrict__ beta)
{
  // Assumptions:
  // 1) blockDim.x == warpSize
  // 2) Tensors are contiguous
  //
  for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
    SharedMemory<U> shared;
    U* buf = shared.getPointer();
    U mu, sigma2;
    cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf);
    const T* lvals = vals + i1 * n2;
    V* ovals = output_vals + i1 * n2;
    U c_invvar = rsqrt(sigma2 + epsilon);
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    if (gamma != NULL && beta != NULL) {
      for (int i = thrx; i < n2; i += numx) {
        U curr = static_cast<U>(lvals[i]);
        ovals[i] = gamma[i] * static_cast<V>(c_invvar * (curr - mu)) + beta[i];
      }
    } else {
      for (int i = thrx; i < n2; i += numx) {
        U curr = static_cast<U>(lvals[i]);
        ovals[i] = static_cast<V>(c_invvar * (curr - mu));
      }
    }
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      mean[i1] = mu;
      invvar[i1] = c_invvar;
    }
    __syncthreads();
  }
}
template<typename T, typename U, typename V> __device__
void cuLoadWriteStridedInputs(
    const int i1_block,
    const int thr_load_row_off,
    const int thr_load_col_off,
    const int i2_off,
    const int row_stride,
    U* warp_buf1,
    U* warp_buf2,
    const T* input,
    const V* dout,
    const int i1_end,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar)
{
  int i1 = i1_block + thr_load_row_off;
  if (i1 < i1_end) {
    U curr_mean = mean[i1];
    U curr_invvar = invvar[i1];
    for (int k = 0; k < blockDim.y; ++k) {
      int i2 = i2_off + k;
      int load_idx = i1 * n2 + i2;
      int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
      if (i2 < n2) {
        U curr_input = static_cast<U>(input[load_idx]);
        U curr_dout = static_cast<U>(dout[load_idx]);
        warp_buf1[write_idx] = curr_dout;
        warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
      } else {
        warp_buf1[write_idx] = U(0);
        warp_buf2[write_idx] = U(0);
      }
    }
  } else {
    for (int k = 0; k < blockDim.y; ++k) {
      int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
      warp_buf1[write_idx] = U(0);
      warp_buf2[write_idx] = U(0);
    }
  }
}

template<typename T, typename U, typename V> __device__
void cuLoadAddStridedInputs(
    const int i1_block,
    const int thr_load_row_off,
    const int thr_load_col_off,
    const int i2_off,
    const int row_stride,
    U* warp_buf1,
    U* warp_buf2,
    const T* input,
    const V* dout,
    const int i1_end,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar)
{
  int i1 = i1_block + thr_load_row_off;
  if (i1 < i1_end) {
    U curr_mean = mean[i1];
    U curr_invvar = invvar[i1];
    for (int k = 0; k < blockDim.y; ++k) {
      int i2 = i2_off + k;
      int load_idx = i1 * n2 + i2;
      int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
      if (i2 < n2) {
        U curr_input = static_cast<U>(input[load_idx]);
        U curr_dout = static_cast<U>(dout[load_idx]);
        warp_buf1[write_idx] += curr_dout;
        warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
      }
    }
  }
}
template<typename T, typename U, typename V> __global__
void cuComputePartGradGammaBeta(
    const V* __restrict__ dout,
    const T* __restrict__ input,
    const int n1,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar,
    U epsilon,
    U* part_grad_gamma,
    U* part_grad_beta)
{
  const int numsegs_n1 = (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y);
  const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
  const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y;
  const int i1_beg_plus_one = (blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y;
  const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
  const int row_stride = blockDim.x + 1;
  const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1);
  const int thr_load_row_off = (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y;
  const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
  SharedMemory<U> shared;
  U* buf = shared.getPointer();
  // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements
  U* warp_buf1 = (U*)buf;
  U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
  // compute partial sums from strided inputs
  // do this to increase number of loads in flight
  cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off,
                           row_stride, warp_buf1, warp_buf2, input, dout,
                           i1_end, n2, mean, invvar);
  for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end;
       i1_block += blockDim.y * blockDim.y) {
    cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off,
                           row_stride, warp_buf1, warp_buf2, input, dout,
                           i1_end, n2, mean, invvar);
  }
  __syncthreads();
  // inter-warp reductions
  // sum within each warp
  U acc1 = U(0);
  U acc2 = U(0);
  for (int k = 0; k < blockDim.y; ++k) {
    int row1 = threadIdx.y + k * blockDim.y;
    int idx1 = row1 * row_stride + threadIdx.x;
    acc1 += warp_buf1[idx1];
    acc2 += warp_buf2[idx1];
  }
  warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1;
  warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;
  __syncthreads();
  // sum all warps
  for (int offset = blockDim.y / 2; offset > 1; offset /= 2) {
    if (threadIdx.y < offset) {
      int row1 = threadIdx.y;
      int row2 = threadIdx.y + offset;
      int idx1 = row1 * row_stride + threadIdx.x;
      int idx2 = row2 * row_stride + threadIdx.x;
      warp_buf1[idx1] += warp_buf1[idx2];
      warp_buf2[idx1] += warp_buf2[idx2];
    }
    __syncthreads();
  }
  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
  if (threadIdx.y == 0 && i2 < n2) {
    int row1 = threadIdx.y;
    int row2 = threadIdx.y + 1;
    int idx1 = row1 * row_stride + threadIdx.x;
    int idx2 = row2 * row_stride + threadIdx.x;
    part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2];
    part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2];
  }
}
template<typename U, typename V> __global__
void cuComputeGradGammaBeta(
    const U* part_grad_gamma,
    const U* part_grad_beta,
    const int part_size,
    const int n1,
    const int n2,
    V* grad_gamma,
    V* grad_beta)
{
  // sum partial gradients for gamma and beta
  SharedMemory<U> shared;
  U* buf = shared.getPointer();
  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
  if (i2 < n2) {
    // each warp does sequential reductions until reduced part_size is num_warps
    int num_warp_reductions = part_size / blockDim.y;
    U sum_gamma = U(0);
    U sum_beta = U(0);
    const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
    const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
    for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) {
      sum_gamma += part_grad_gamma_ptr[warp_offset * n2];
      sum_beta += part_grad_beta_ptr[warp_offset * n2];
    }
    // inter-warp reductions
    const int nbsize3 = blockDim.x * blockDim.y / 2;
    for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {
      // top half write to shared memory
      if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
        const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
        buf[write_idx] = sum_gamma;
        buf[write_idx + nbsize3] = sum_beta;
      }
      __syncthreads();
      // bottom half sums
      if (threadIdx.y < offset) {
        const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
        sum_gamma += buf[read_idx];
        sum_beta += buf[read_idx + nbsize3];
      }
      __syncthreads();
    }
    // write out fully summed gradients
    if (threadIdx.y == 0) {
      grad_gamma[i2] = sum_gamma;
      grad_beta[i2] = sum_beta;
    }
  }
}
template<typename T, typename U, typename V> __global__
void cuComputeGradInput(
    const V* __restrict__ dout,
    const T* __restrict__ input,
    const int n1,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar,
    U epsilon,
    const V* gamma,
    T* grad_input)
{
  for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
    U sum_loss1 = U(0);
    U sum_loss2 = U(0);
    const U c_mean = mean[i1];
    const U c_invvar = invvar[i1];
    const T* k_input = input + i1 * n2;
    const V* k_dout = dout + i1 * n2;
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    if (gamma != NULL) {
      int l = 4 * thrx;
      for (; l + 3 < n2; l += 4 * numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l + k]);
          const U c_loss = static_cast<U>(k_dout[l + k]);
          sum_loss1 += c_loss * gamma[l + k];
          sum_loss2 += c_loss * gamma[l + k] * (c_h - c_mean) * c_invvar;
        }
      }
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss * gamma[l];
        sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
      }
    } else {
      int l = 4 * thrx;
      for (; l + 3 < n2; l += 4 * numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l + k]);
          const U c_loss = static_cast<U>(k_dout[l + k]);
          sum_loss1 += c_loss;
          sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
        }
      }
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss;
        sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
      }
    }
    // intra-warp reductions
    for (int mask = blockDim.x / 2; mask > 0; mask /= 2) {
      sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
      sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
    }
    // inter-warp reductions
    if (blockDim.y > 1) {
      SharedMemory<U> shared;
      U* buf = shared.getPointer();
      for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
          const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
          buf[2 * wrt_i] = sum_loss1;
          buf[2 * wrt_i + 1] = sum_loss2;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.y < offset) {
          const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
          sum_loss1 += buf[2 * read_i];
          sum_loss2 += buf[2 * read_i + 1];
        }
        __syncthreads();
      }
      if (threadIdx.y == 0) {
        buf[2 * threadIdx.x] = sum_loss1;
        buf[2 * threadIdx.x + 1] = sum_loss2;
      }
      __syncthreads();
      if (threadIdx.y != 0) {
        sum_loss1 = buf[2 * threadIdx.x];
        sum_loss2 = buf[2 * threadIdx.x + 1];
      }
    }
    // all threads now have the two sums over l
    U fH = (U)n2;
    U term1 = (U(1) / fH) * c_invvar;
    T* k_grad_input = grad_input + i1 * n2;
    if (gamma != NULL) {
      for (int l = thrx; l < n2; l += numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss * gamma[l];
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    } else {
      for (int l = thrx; l < n2; l += numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss;
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    }
    // prevent race where buf is written again before reads are done
    __syncthreads();
  }
}
template<typename T, typename U, typename V>
void HostApplyLayerNorm(
    V* output,
    U* mean,
    U* invvar,
    const T* input,
    int n1,
    int n2,
    double epsilon,
    const V* gamma,
    const V* beta)
{
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const dim3 threads(32, 4, 1);
  const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
  int nshared =
      threads.y > 1 ? threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0;
  cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(
      output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta);
}

void cuda_layer_norm(
    at::Tensor* output,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon)
{
  using namespace at;
  DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
      input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel",
      HostApplyLayerNorm(
          output->DATA_PTR<scalar_t_out>(),
          mean->DATA_PTR<float>(),
          invvar->DATA_PTR<float>(),
          input->DATA_PTR<scalar_t_in>(),
          n1, n2,
          epsilon,
          gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
          beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL);
  )
}
template<typename T, typename U, typename V>
void HostLayerNormGradient(
    const V* dout,
    const U* mean,
    const U* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    const V* gamma,
    const V* beta,
    double epsilon,
    T* grad_input,
    V* grad_gamma,
    V* grad_beta)
{
  auto stream = at::cuda::getCurrentCUDAStream().stream();

  if (gamma != NULL && beta != NULL) {
    // compute grad_gamma(j) and grad_beta(j)
    const int part_size = 16;
    const dim3 threads2(32, 4, 1);
    const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1);
    const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
    const int nshared2_b = threads2.x * threads2.y * sizeof(U);
    const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
    at::Tensor part_grad_gamma = at::empty(
        {part_size, n2}, input->options().dtype(at::ScalarType::Float));
    at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
    cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
        dout,
        input->DATA_PTR<T>(),
        n1, n2,
        mean,
        invvar,
        U(epsilon),
        part_grad_gamma.DATA_PTR<U>(),
        part_grad_beta.DATA_PTR<U>());

    const dim3 threads3(32, 8, 1);
    const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1);
    const int nshared3 = threads3.x * threads3.y * sizeof(U);
    cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
        part_grad_gamma.DATA_PTR<U>(),
        part_grad_beta.DATA_PTR<U>(),
        part_size,
        n1, n2,
        grad_gamma,
        grad_beta);
  }

  // compute grad_input
  const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
  const dim3 threads1(32, 4, 1);
  int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0;
  cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
      dout,
      input->DATA_PTR<T>(),
      n1, n2,
      mean,
      invvar,
      U(epsilon),
      gamma,
      grad_input);
}

void cuda_layer_norm_gradient(
    at::Tensor* dout,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon,
    at::Tensor* grad_input,
    at::Tensor* grad_gamma,
    at::Tensor* grad_beta)
{
  using namespace at;
  DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
      input->scalar_type(), gamma->scalar_type(), "cuda_layer_norm_gradient_kernel",
      HostLayerNormGradient(
          dout->DATA_PTR<scalar_t_out>(),
          mean->DATA_PTR<float>(),
          invvar->DATA_PTR<float>(),
          input,
          n1, n2,
          // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
          // if gamma Tensor is NULL on input.
          gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
          gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL,
          epsilon,
          grad_input->DATA_PTR<scalar_t_in>(),
          gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
          gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);
  )
}
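The statistics pipeline in the deleted kernel is Welford's online algorithm per thread, merged across lanes and warps with Chan's pairwise update. A scalar Python sketch of the two update rules mirrored by cuWelfordOnlineSum and cuChanOnlineSum above:

    def welford_online_sum(curr, mu, sigma2, count):
        # One-sample update: running mean and sum of squared deviations (M2).
        count += 1.0
        delta = curr - mu
        mu += delta / count
        sigma2 += delta * (curr - mu)
        return mu, sigma2, count

    def chan_online_sum(muB, sigma2B, countB, mu, sigma2, count):
        # Merge two partial aggregates; this is what the warp-shuffle and
        # shared-memory reductions in the kernel apply pairwise.
        delta = muB - mu
        nA, nB = count, countB
        count = nA + nB
        if count > 0:
            mu = (nA * mu + nB * muB) / count
            sigma2 = sigma2 + sigma2B + delta * delta * nA * nB / count
        else:
            mu, sigma2 = 0.0, 0.0
        return mu, sigma2, count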
megatron/fused_kernels/tests/test_fused_kernels.py  View file @ d650e6a2

@@ -11,7 +11,7 @@ from megatron.fused_kernels import load
 def test_load_fused_kernels():
     try:
-        import fused_mix_prec_layer_norm_cuda
+        import fused_layer_norm_cuda
         import scaled_masked_softmax_cuda
         import scaled_upper_triang_masked_softmax_cuda
         import torch
@@ -21,7 +21,6 @@ def test_load_fused_kernels():
         print("[Fail] load_fused_kernels")
         raise e

 def test_fused_softmax():
     bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
     tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
@@ -328,7 +327,7 @@ def test_masked_softmax_backward():
 def test_allmasked_softmax_forward():
     import scaled_masked_softmax_cuda

     batch = 2
     attn = 16
@@ -345,7 +344,7 @@ def test_allmasked_softmax_forward():
 def test_allmasked_softmax_backward():
     import scaled_masked_softmax_cuda

     batch = 2
     attn = 16
     scale_t = torch.tensor([1.0])
megatron/model/fused_layer_norm.py  View file @ d650e6a2

@@ -18,40 +18,11 @@ try:
 except:
     HAVE_PERSIST_LAYER_NORM = False

-global fused_mix_prec_layer_norm_cuda
-fused_mix_prec_layer_norm_cuda = None
-
-
-class FusedLayerNormAffineFunction(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, input, weight, bias, normalized_shape, eps):
-        ctx.normalized_shape = normalized_shape
-        ctx.eps = eps
-        input_ = input.contiguous()
-        weight_ = weight.contiguous()
-        bias_ = bias.contiguous()
-        output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
-            input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
-        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        input_, weight_, bias_, mean, invvar = ctx.saved_tensors
-        grad_input = grad_weight = grad_bias = None
-        grad_input, grad_weight, grad_bias \
-            = fused_mix_prec_layer_norm_cuda.backward_affine(
-                grad_output.contiguous(), mean, invvar,
-                input_, ctx.normalized_shape,
-                weight_, bias_, ctx.eps)
-        return grad_input, grad_weight, grad_bias, None, None
+from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction

+global fused_layer_norm_cuda
+fused_layer_norm_cuda = None

 class MixedFusedLayerNorm(torch.nn.Module):
@@ -64,9 +35,8 @@ class MixedFusedLayerNorm(torch.nn.Module):
         self.apply_layernorm_1p = apply_layernorm_1p

-        global fused_mix_prec_layer_norm_cuda
-        fused_mix_prec_layer_norm_cuda = importlib.import_module(
-            "fused_mix_prec_layer_norm_cuda")
+        global fused_layer_norm_cuda
+        fused_layer_norm_cuda = importlib.import_module(
+            "fused_layer_norm_cuda")

         # List of hiddens sizes supported in the persistent layer norm kernel
         # If the hidden size is not supported, fall back to the non-persistent
@@ -87,7 +57,7 @@ class MixedFusedLayerNorm(torch.nn.Module):
         self.reset_parameters()
         self.no_persist_layer_norm = no_persist_layer_norm
         self.sequence_parallel = sequence_parallel

         # set sequence parallelism flag on weight and bias parameters
         setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
         setattr(self.bias, 'sequence_parallel', self.sequence_parallel)
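With this hunk applied, MixedFusedLayerNorm routes through Apex's FusedLayerNormAffineFunction and imports Apex's fused_layer_norm_cuda module instead of the deleted in-tree one. A sketch of the post-commit call path, assuming Apex is installed with its CUDA extensions built (constructor arguments beyond the hidden size are assumed to keep their defaults):

    import torch
    from megatron.model.fused_layer_norm import MixedFusedLayerNorm

    # Under the hood this now calls apex's fused layer norm kernels.
    ln = MixedFusedLayerNorm(1024).cuda()
    y = ln(torch.randn(4, 1024, device='cuda'))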