OpenDAS / Megatron-LM
Commit 8bed1d63
Authored Dec 21, 2020 by mohammad; committed by Deepak Narayanan, Dec 22, 2020

    Add residual connection in fp32 machinery

Parent: 62632d39

Showing 7 changed files with 1556 additions and 0 deletions (+1556, -0)
Files changed:

  megatron/arguments.py                              +4    -0
  megatron/fused_kernels/__init__.py                 +26   -0
  megatron/fused_kernels/compat.h                    +31   -0
  megatron/fused_kernels/layer_norm_cuda.cpp         +260  -0
  megatron/fused_kernels/layer_norm_cuda_kernel.cu   +829  -0
  megatron/fused_kernels/type_shim.h                 +227  -0
  megatron/model/fused_layer_norm.py                 +179  -0
megatron/arguments.py

@@ -200,6 +200,10 @@ def parse_args(extra_args_provider=None, defaults={},

     if args.scaled_masked_softmax_fusion:
         fused_kernels.load_scaled_masked_softmax_fusion_kernel()

+    # Load mixed precision fused layer norm.
+    if args.fp32_residual_connection:
+        fused_kernels.load_fused_mix_prec_layer_norm_kernel()
+
     _print_args(args)

     return args
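The new branch assumes an `args.fp32_residual_connection` attribute on the parsed namespace; the flag definition itself is outside this hunk. Purely as orientation, a minimal sketch of how a boolean flag of this kind is typically declared with argparse. The flag name below is an assumption for illustration, not part of the commit:

# Illustrative sketch only; the real argument definitions live elsewhere in
# megatron/arguments.py and are not shown in this diff.
import argparse

parser = argparse.ArgumentParser(description='Megatron-LM arguments (sketch)')
parser.add_argument('--fp32-residual-connection', action='store_true',
                    help='Keep residual connections in fp32 (assumed flag name).')
args = parser.parse_args(['--fp32-residual-connection'])
assert args.fp32_residual_connection   # argparse maps dashes to underscores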
megatron/fused_kernels/__init__.py

@@ -98,3 +98,29 @@ def load_scaled_masked_softmax_fusion_kernel():
                            '--expt-relaxed-constexpr',
                            '--expt-extended-lambda',
                            '--use_fast_math'] + cc_flag)
+
+
+def load_fused_mix_prec_layer_norm_kernel():
+
+    # Check, if CUDA11 is installed for compute capability 8.0
+    cc_flag = []
+    _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
+    if int(bare_metal_major) >= 11:
+        cc_flag.append('-gencode')
+        cc_flag.append('arch=compute_80,code=sm_80')
+
+    srcpath = pathlib.Path(__file__).parent.absolute()
+    buildpath = srcpath / 'build'
+
+    create_build_dir(buildpath)
+
+    fused_mix_prec_layer_norm_cuda = cpp_extension.load(
+        name='fused_mix_prec_layer_norm_cuda',
+        sources=[srcpath / 'layer_norm_cuda.cpp',
+                 srcpath / 'layer_norm_cuda_kernel.cu'],
+        build_directory=buildpath,
+        extra_cflags=['-O3'],
+        extra_cuda_cflags=['-O3',
+                           '-gencode', 'arch=compute_70,code=sm_70',
+                           '-maxrregcount=50',
+                           '--use_fast_math'] + cc_flag)
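`get_cuda_bare_metal_version` and `create_build_dir` are helpers defined earlier in this file and are not part of the hunk. As a rough sketch of what the version probe presumably does (run the toolkit's `nvcc -V` and pull out the release major/minor so the `cc_flag` gating above can add sm_80 on CUDA 11+), the following is illustrative and may differ from the real helper:

# Hedged sketch of an nvcc version probe; the actual helper in this file
# may be implemented differently.
import os
import subprocess

def sketch_get_cuda_bare_metal_version(cuda_dir):
    # Example nvcc output: "... Cuda compilation tools, release 11.0, V11.0.194"
    raw_output = subprocess.check_output([os.path.join(cuda_dir, 'bin', 'nvcc'), '-V'],
                                         universal_newlines=True)
    tokens = raw_output.split()
    release = tokens[tokens.index('release') + 1].rstrip(',').split('.')
    bare_metal_major, bare_metal_minor = release[0], release[1]
    return raw_output, bare_metal_major, bare_metal_minor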
megatron/fused_kernels/compat.h (new file, mode 100644)
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* This code is copied from NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
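compat.h expects the build to define version macros such as VERSION_GE_1_1 and VERSION_GE_1_3 when compiling against newer PyTorch; the `cpp_extension.load()` call in this commit does not show those defines being passed, so how they get set here is not visible in the diff. For orientation only, a speculative sketch of how such flags are commonly derived from `torch.__version__`, in the style apex uses:

# Speculative sketch: derive -DVERSION_GE_1_X flags from the installed torch.
# Whether and where this commit passes such flags is not shown in the diff.
import torch

def sketch_version_macros():
    major, minor = (int(x) for x in torch.__version__.split('.')[:2])
    flags = []
    if (major, minor) >= (1, 1):
        flags.append('-DVERSION_GE_1_1')
    if (major, minor) >= (1, 3):
        flags.append('-DVERSION_GE_1_3')
    return flags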
megatron/fused_kernels/layer_norm_cuda.cpp (new file, mode 100644)
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* This code is copied from NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
#include <torch/extension.h>
#include <vector>
#include <cassert>
#include "compat.h"
namespace {

void compute_n1_n2(
    at::Tensor input,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    int& n1, int& n2)
{
  int idiff = input.ndimension() - normalized_shape.size();
  n2 = 1;
  for (int i = 0; i < (int)normalized_shape.size(); ++i) {
    assert( input.sizes()[i + idiff] == normalized_shape[i] );
    n2 *= normalized_shape[i];
  }
  n1 = 1;
  for (int i = 0; i < idiff; ++i) {
    n1 *= input.sizes()[i];
  }
}

void check_args(
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor gamma,
    at::Tensor beta)
{
  TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
  TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
}

void check_args(
    at::Tensor input,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    int& n1, int& n2)
{
  int64_t normalized_ndim = normalized_shape.size();

  if (normalized_ndim < 1) {
    std::stringstream ss;
    ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
       << "containing at least one element, but got normalized_shape="
       << normalized_shape;
    throw std::runtime_error(ss.str());
  }

  auto input_shape = input.sizes();
  auto input_ndim = input.dim();

  if (input_ndim < normalized_ndim ||
      !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {
    std::stringstream ss;
    ss << "Given normalized_shape=" << normalized_shape
       << ", expected input with shape [*";
    for (auto size : normalized_shape) {
      ss << ", " << size;
    }
    ss << "], but got input of size" << input_shape;
    throw std::runtime_error(ss.str());
  }

  compute_n1_n2(input, normalized_shape, n1, n2);
}

void check_args(
    at::Tensor input,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor gamma,
    at::Tensor beta,
    int& n1, int& n2)
{
  check_args(input, normalized_shape, n1, n2);
  check_args(normalized_shape, gamma, beta);
}

}

void cuda_layer_norm(
    at::Tensor* output, at::Tensor* mean, at::Tensor* invvar,
    at::Tensor* input, int n1, int n2,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor* gamma, at::Tensor* beta, double epsilon);

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<at::Tensor> layer_norm(
    at::Tensor input,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    double epsilon)
{
  CHECK_INPUT(input);
  int n1, n2;
  check_args(input, normalized_shape, n1, n2);
  at::Tensor output = at::empty_like(input);
  at::Tensor mean = at::empty({n1}, input.options().dtype(
      input.scalar_type() == at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
  at::Tensor invvar = at::empty_like(mean);
  cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
                  normalized_shape, NULL, NULL, epsilon);
  return {output, mean, invvar};
}

std::vector<at::Tensor> layer_norm_affine(
    at::Tensor input,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor gamma,
    at::Tensor beta,
    double epsilon)
{
  CHECK_INPUT(input);
  CHECK_INPUT(gamma);
  CHECK_INPUT(beta);
  int n1, n2;
  check_args(input, normalized_shape, gamma, beta, n1, n2);
  at::Tensor output = at::empty_like(input, input.options().dtype(at::ScalarType::Half));
  at::Tensor mean = at::empty({n1}, input.options().dtype(
      input.scalar_type() == at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
  at::Tensor invvar = at::empty_like(mean);
  cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
                  normalized_shape, &gamma, &beta, epsilon);
  return {output, mean, invvar};
}

void cuda_layer_norm_gradient(
    at::Tensor* dout, at::Tensor* mean, at::Tensor* invvar,
    at::Tensor* input, int n1, int n2,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor* gamma, at::Tensor* beta, double epsilon,
    at::Tensor* grad_input, at::Tensor* grad_gamma, at::Tensor* grad_beta);

at::Tensor layer_norm_gradient(
    at::Tensor dout,
    at::Tensor mean,
    at::Tensor invvar,
    at::Tensor input,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    double epsilon)
{
  CHECK_INPUT(dout);
  CHECK_INPUT(mean);
  CHECK_INPUT(invvar);
  CHECK_INPUT(input);
  int n1, n2;
  check_args(input, normalized_shape, n1, n2);
  at::Tensor grad_input = at::empty_like(input);
  cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
                           normalized_shape, NULL, NULL, epsilon,
                           &grad_input, NULL, NULL);
  return grad_input;
}

std::vector<at::Tensor> layer_norm_gradient_affine(
    at::Tensor dout,
    at::Tensor mean,
    at::Tensor invvar,
    at::Tensor input,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor gamma,
    at::Tensor beta,
    double epsilon)
{
  CHECK_INPUT(dout);
  CHECK_INPUT(mean);
  CHECK_INPUT(invvar);
  CHECK_INPUT(input);
  CHECK_INPUT(gamma);
  CHECK_INPUT(beta);
  int n1, n2;
  check_args(input, normalized_shape, gamma, beta, n1, n2);
  at::Tensor grad_input = at::empty_like(input);
  at::Tensor grad_gamma = at::empty_like(gamma);
  at::Tensor grad_beta = at::empty_like(beta);
  cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
                           normalized_shape, &gamma, &beta, epsilon,
                           &grad_input, &grad_gamma, &grad_beta);
  return {grad_input, grad_gamma, grad_beta};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
  m.def("forward", &layer_norm, "LayerNorm forward (CUDA)");
  m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)");
  m.def("backward", &layer_norm_gradient, "LayerNorm backward (CUDA)");
}
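These bindings are what megatron/model/fused_layer_norm.py (added later in this commit) calls through importlib. A small usage sketch of the exposed API, assuming the extension has been built by load_fused_mix_prec_layer_norm_kernel() and a CUDA device is available; the shapes and the fp32-input / fp16-gamma combination are illustrative, inferred from the Half output dtype chosen in layer_norm_affine:

# Usage sketch only; requires the JIT-built extension and a GPU.
import importlib
import torch

fused = importlib.import_module('fused_mix_prec_layer_norm_cuda')

hidden = 1024
x = torch.randn(8, 512, hidden, dtype=torch.float32, device='cuda').contiguous()
weight = torch.ones(hidden, dtype=torch.half, device='cuda')
bias = torch.zeros(hidden, dtype=torch.half, device='cuda')

# forward_affine returns the normalized output plus the per-row statistics
# (mean, invvar) that the backward binding consumes.
output, mean, invvar = fused.forward_affine(x, (hidden,), weight, bias, 1e-5)

grad_output = torch.randn_like(output)
grad_input, grad_weight, grad_bias = fused.backward_affine(
    grad_output.contiguous(), mean, invvar, x, (hidden,), weight, bias, 1e-5)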
megatron/fused_kernels/layer_norm_cuda_kernel.cu (new file, mode 100644)
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* This code is copied from NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
#include <THC/THCDeviceUtils.cuh>
#include <cuda.h>
#include <cuda_runtime.h>
#include "type_shim.h"
template<typename U> __device__
void cuWelfordOnlineSum(
    const U curr, U& mu, U& sigma2, U& count)
{
  count = count + U(1);
  U delta = curr - mu;
  U lmean = mu + delta / count;
  mu = lmean;
  U delta2 = curr - lmean;
  sigma2 = sigma2 + delta * delta2;
}

template<typename U> __device__
void cuChanOnlineSum(
    const U muB, const U sigma2B, const U countB,
    U& mu, U& sigma2, U& count)
{
  U delta = muB - mu;
  U nA = count;
  U nB = countB;
  count = count + countB;
  U nX = count;
  if (nX > U(0)) {
    nA = nA / nX;
    nB = nB / nX;
    mu = nA * mu + nB * muB;
    sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX;
  } else {
    mu = U(0);
    sigma2 = U(0);
  }
}

template<typename T, typename U> __device__
void cuWelfordMuSigma2(
    const T* __restrict__ vals,
    const int n1, const int n2, const int i1,
    U& mu, U& sigma2, U* buf)
{
  // Assumptions:
  // 1) blockDim.x == warpSize
  // 2) Tensor is contiguous
  // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
  //
  // compute variance and mean over n2
  U count = U(0);
  mu = U(0);
  sigma2 = U(0);
  if (i1 < n1) {
    // one warp normalizes one n1 index,
    // synchronization is implicit
    // initialize with standard Welford algorithm
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    const T* lvals = vals + i1 * n2;
    int l = 4 * thrx;
    for (; l + 3 < n2; l += 4 * numx) {
      for (int k = 0; k < 4; ++k) {
        U curr = static_cast<U>(lvals[l + k]);
        cuWelfordOnlineSum<U>(curr, mu, sigma2, count);
      }
    }
    for (; l < n2; ++l) {
      U curr = static_cast<U>(lvals[l]);
      cuWelfordOnlineSum<U>(curr, mu, sigma2, count);
    }
    // intra-warp reductions
    for (int l = 0; l <= 4; ++l) {
      int srcLaneB = (threadIdx.x + (1 << l)) & 31;
      U muB = WARP_SHFL(mu, srcLaneB);
      U countB = WARP_SHFL(count, srcLaneB);
      U sigma2B = WARP_SHFL(sigma2, srcLaneB);
      cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
    }
    // threadIdx.x == 0 has correct values for each warp
    // inter-warp reductions
    if (blockDim.y > 1) {
      U* ubuf = (U*)buf;
      U* ibuf = (U*)(ubuf + blockDim.y);
      for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2 * offset) {
          const int wrt_y = threadIdx.y - offset;
          ubuf[2 * wrt_y] = mu;
          ubuf[2 * wrt_y + 1] = sigma2;
          ibuf[wrt_y] = count;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.x == 0 && threadIdx.y < offset) {
          U muB = ubuf[2 * threadIdx.y];
          U sigma2B = ubuf[2 * threadIdx.y + 1];
          U countB = ibuf[threadIdx.y];
          cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
        }
        __syncthreads();
      }
      // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
      if (threadIdx.x == 0 && threadIdx.y == 0) {
        ubuf[0] = mu;
        ubuf[1] = sigma2;
      }
      __syncthreads();
      mu = ubuf[0];
      sigma2 = ubuf[1] / U(n2);
      // don't care about final value of count, we know count == n2
    } else {
      mu = WARP_SHFL(mu, 0);
      sigma2 = WARP_SHFL(sigma2 / U(n2), 0);
    }
  }
}

template<> __device__
void cuWelfordMuSigma2(
    const at::Half* __restrict__ vals,
    const int n1, const int n2, const int i1,
    float& mu, float& sigma2, float* buf)
{
  // Assumptions:
  // 1) blockDim.x == warpSize
  // 2) Tensor is contiguous
  // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
  //
  // compute variance and mean over n2
  float count = 0.0f;
  mu = float(0);
  sigma2 = float(0);
  if (i1 < n1) {
    // one warp normalizes one n1 index,
    // synchronization is implicit
    // initialize with standard Welford algorithm
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    const at::Half* lvals = vals + i1 * n2;
    int l = 8 * thrx;
    if ((((size_t)lvals) & 3) != 0) {
      // 16 bit alignment
      // first thread consumes first point
      if (thrx == 0) {
        float curr = static_cast<float>(lvals[0]);
        cuWelfordOnlineSum(curr, mu, sigma2, count);
      }
      ++l;
    }
    // at this point, lvals[l] are 32 bit aligned for all threads.
    for (; l + 7 < n2; l += 8 * numx) {
      for (int k = 0; k < 8; k += 2) {
        float2 curr = __half22float2(*((__half2*)(lvals + l + k)));
        cuWelfordOnlineSum(curr.x, mu, sigma2, count);
        cuWelfordOnlineSum(curr.y, mu, sigma2, count);
      }
    }
    for (; l < n2; ++l) {
      float curr = static_cast<float>(lvals[l]);
      cuWelfordOnlineSum(curr, mu, sigma2, count);
    }
    // intra-warp reductions
    for (int l = 0; l <= 4; ++l) {
      int srcLaneB = (threadIdx.x + (1 << l)) & 31;
      float muB = WARP_SHFL(mu, srcLaneB);
      float countB = WARP_SHFL(count, srcLaneB);
      float sigma2B = WARP_SHFL(sigma2, srcLaneB);
      cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
    }
    // threadIdx.x == 0 has correct values for each warp
    // inter-warp reductions
    if (blockDim.y > 1) {
      float* ubuf = (float*)buf;
      float* ibuf = (float*)(ubuf + blockDim.y);
      for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2 * offset) {
          const int wrt_y = threadIdx.y - offset;
          ubuf[2 * wrt_y] = mu;
          ubuf[2 * wrt_y + 1] = sigma2;
          ibuf[wrt_y] = count;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.x == 0 && threadIdx.y < offset) {
          float muB = ubuf[2 * threadIdx.y];
          float sigma2B = ubuf[2 * threadIdx.y + 1];
          float countB = ibuf[threadIdx.y];
          cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
        }
        __syncthreads();
      }
      // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
      if (threadIdx.x == 0 && threadIdx.y == 0) {
        ubuf[0] = mu;
        ubuf[1] = sigma2;
      }
      __syncthreads();
      mu = ubuf[0];
      sigma2 = ubuf[1] / float(n2);
      // don't care about final value of count, we know count == n2
    } else {
      mu = WARP_SHFL(mu, 0);
      sigma2 = WARP_SHFL(sigma2 / float(n2), 0);
    }
  }
}

template<typename U> U rsqrt(U v) {
  return U(1) / sqrt(v);
}
template<> float rsqrt(float v) {
  return rsqrtf(v);
}
template<> double rsqrt(double v) {
  return rsqrt(v);
}

namespace {
// This is the un-specialized struct. Note that we prevent instantiation of this
// struct by putting an undefined symbol in the function body so it won't compile.
//  template <typename T>
//  struct SharedMemory
//  {
//      // Ensure that we won't compile any un-specialized types
//      __device__ T *getPointer()
//      {
//          extern __device__ void error(void);
//          error();
//          return NULL;
//      }
//  };
// https://github.com/NVIDIA/apex/issues/246
template <typename T>
struct SharedMemory;

template <>
struct SharedMemory <float>
{
  __device__ float *getPointer()
  {
    extern __shared__ float s_float[];
    return s_float;
  }
};

template <>
struct SharedMemory <double>
{
  __device__ double *getPointer()
  {
    extern __shared__ double s_double[];
    return s_double;
  }
};
}

template<typename T, typename U, typename V> __global__
void cuApplyLayerNorm(
    V* __restrict__ output_vals,
    U* __restrict__ mean,
    U* __restrict__ invvar,
    const T* __restrict__ vals,
    const int n1, const int n2,
    const U epsilon,
    const V* __restrict__ gamma,
    const V* __restrict__ beta)
{
  // Assumptions:
  // 1) blockDim.x == warpSize
  // 2) Tensors are contiguous
  //
  for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
    SharedMemory<U> shared;
    U* buf = shared.getPointer();
    U mu, sigma2;
    cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf);
    const T* lvals = vals + i1 * n2;
    V* ovals = output_vals + i1 * n2;
    U c_invvar = rsqrt(sigma2 + epsilon);
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    if (gamma != NULL && beta != NULL) {
      for (int i = thrx; i < n2; i += numx) {
        U curr = static_cast<U>(lvals[i]);
        ovals[i] = gamma[i] * static_cast<V>(c_invvar * (curr - mu)) + beta[i];
      }
    } else {
      for (int i = thrx; i < n2; i += numx) {
        U curr = static_cast<U>(lvals[i]);
        ovals[i] = static_cast<V>(c_invvar * (curr - mu));
      }
    }
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      mean[i1] = mu;
      invvar[i1] = c_invvar;
    }
  }
}

template<typename T, typename U, typename V> __device__
void cuLoadWriteStridedInputs(
    const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
    const int i2_off, const int row_stride,
    U* warp_buf1, U* warp_buf2,
    const T* input, const V* dout,
    const int i1_end, const int n2,
    const U* __restrict__ mean, const U* __restrict__ invvar)
{
  int i1 = i1_block + thr_load_row_off;
  if (i1 < i1_end) {
    U curr_mean = mean[i1];
    U curr_invvar = invvar[i1];
    for (int k = 0; k < blockDim.y; ++k) {
      int i2 = i2_off + k;
      int load_idx = i1 * n2 + i2;
      int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
      if (i2 < n2) {
        U curr_input = static_cast<U>(input[load_idx]);
        U curr_dout = static_cast<U>(dout[load_idx]);
        warp_buf1[write_idx] = curr_dout;
        warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
      } else {
        warp_buf1[write_idx] = U(0);
        warp_buf2[write_idx] = U(0);
      }
    }
  } else {
    for (int k = 0; k < blockDim.y; ++k) {
      int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
      warp_buf1[write_idx] = U(0);
      warp_buf2[write_idx] = U(0);
    }
  }
}

template<typename T, typename U, typename V> __device__
void cuLoadAddStridedInputs(
    const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
    const int i2_off, const int row_stride,
    U* warp_buf1, U* warp_buf2,
    const T* input, const V* dout,
    const int i1_end, const int n2,
    const U* __restrict__ mean, const U* __restrict__ invvar)
{
  int i1 = i1_block + thr_load_row_off;
  if (i1 < i1_end) {
    U curr_mean = mean[i1];
    U curr_invvar = invvar[i1];
    for (int k = 0; k < blockDim.y; ++k) {
      int i2 = i2_off + k;
      int load_idx = i1 * n2 + i2;
      int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
      if (i2 < n2) {
        U curr_input = static_cast<U>(input[load_idx]);
        U curr_dout = static_cast<U>(dout[load_idx]);
        warp_buf1[write_idx] += curr_dout;
        warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
      }
    }
  }
}

template<typename T, typename U, typename V> __global__
void cuComputePartGradGammaBeta(
    const V* __restrict__ dout,
    const T* __restrict__ input,
    const int n1, const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar,
    U epsilon,
    U* part_grad_gamma,
    U* part_grad_beta)
{
  const int numsegs_n1 = (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y);
  const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
  const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y;
  const int i1_beg_plus_one = (blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y;
  const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
  const int row_stride = blockDim.x + 1;
  const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1);
  const int thr_load_row_off = (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y;
  const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
  SharedMemory<U> shared;
  // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements
  U* buf = shared.getPointer();
  U* warp_buf1 = (U*)buf;
  U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
  // compute partial sums from strided inputs
  // do this to increase number of loads in flight
  cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off, row_stride,
                           warp_buf1, warp_buf2, input, dout, i1_end, n2, mean, invvar);
  for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end;
       i1_block += blockDim.y * blockDim.y) {
    cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off, row_stride,
                           warp_buf1, warp_buf2, input, dout, i1_end, n2, mean, invvar);
  }
  __syncthreads();
  // inter-warp reductions
  // sum within each warp
  U acc1 = U(0);
  U acc2 = U(0);
  for (int k = 0; k < blockDim.y; ++k) {
    int row1 = threadIdx.y + k * blockDim.y;
    int idx1 = row1 * row_stride + threadIdx.x;
    acc1 += warp_buf1[idx1];
    acc2 += warp_buf2[idx1];
  }
  warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1;
  warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;
  __syncthreads();
  // sum all warps
  for (int offset = blockDim.y / 2; offset > 1; offset /= 2) {
    if (threadIdx.y < offset) {
      int row1 = threadIdx.y;
      int row2 = threadIdx.y + offset;
      int idx1 = row1 * row_stride + threadIdx.x;
      int idx2 = row2 * row_stride + threadIdx.x;
      warp_buf1[idx1] += warp_buf1[idx2];
      warp_buf2[idx1] += warp_buf2[idx2];
    }
    __syncthreads();
  }
  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
  if (threadIdx.y == 0 && i2 < n2) {
    int row1 = threadIdx.y;
    int row2 = threadIdx.y + 1;
    int idx1 = row1 * row_stride + threadIdx.x;
    int idx2 = row2 * row_stride + threadIdx.x;
    part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2];
    part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2];
  }
}

template<typename U, typename V> __global__
void cuComputeGradGammaBeta(
    const U* part_grad_gamma,
    const U* part_grad_beta,
    const int part_size,
    const int n1, const int n2,
    V* grad_gamma,
    V* grad_beta)
{
  // sum partial gradients for gamma and beta
  SharedMemory<U> shared;
  U* buf = shared.getPointer();
  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
  if (i2 < n2) {
    // each warp does sequential reductions until reduced part_size is num_warps
    int num_warp_reductions = part_size / blockDim.y;
    U sum_gamma = U(0);
    U sum_beta = U(0);
    const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
    const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
    for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) {
      sum_gamma += part_grad_gamma_ptr[warp_offset * n2];
      sum_beta += part_grad_beta_ptr[warp_offset * n2];
    }
    // inter-warp reductions
    const int nbsize3 = blockDim.x * blockDim.y / 2;
    for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {
      // top half write to shared memory
      if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
        const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
        buf[write_idx] = sum_gamma;
        buf[write_idx + nbsize3] = sum_beta;
      }
      __syncthreads();
      // bottom half sums
      if (threadIdx.y < offset) {
        const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
        sum_gamma += buf[read_idx];
        sum_beta += buf[read_idx + nbsize3];
      }
      __syncthreads();
    }
    // write out fully summed gradients
    if (threadIdx.y == 0) {
      grad_gamma[i2] = sum_gamma;
      grad_beta[i2] = sum_beta;
    }
  }
}

template<typename T, typename U, typename V> __global__
void cuComputeGradInput(
    const V* __restrict__ dout,
    const T* __restrict__ input,
    const int n1, const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar,
    U epsilon,
    const V* gamma,
    T* grad_input)
{
  for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
    U sum_loss1 = U(0);
    U sum_loss2 = U(0);
    const U c_mean = mean[i1];
    const U c_invvar = invvar[i1];
    const T* k_input = input + i1 * n2;
    const V* k_dout = dout + i1 * n2;
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    if (gamma != NULL) {
      int l = 4 * thrx;
      for (; l + 3 < n2; l += 4 * numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l + k]);
          const U c_loss = static_cast<U>(k_dout[l + k]);
          sum_loss1 += c_loss * gamma[l + k];
          sum_loss2 += c_loss * gamma[l + k] * (c_h - c_mean) * c_invvar;
        }
      }
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss * gamma[l];
        sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
      }
    } else {
      int l = 4 * thrx;
      for (; l + 3 < n2; l += 4 * numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l + k]);
          const U c_loss = static_cast<U>(k_dout[l + k]);
          sum_loss1 += c_loss;
          sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
        }
      }
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss;
        sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
      }
    }
    // intra-warp reductions
    for (int mask = blockDim.x / 2; mask > 0; mask /= 2) {
      sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
      sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
    }
    // inter-warp reductions
    if (blockDim.y > 1) {
      SharedMemory<U> shared;
      U* buf = shared.getPointer();
      for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
          const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
          buf[2 * wrt_i] = sum_loss1;
          buf[2 * wrt_i + 1] = sum_loss2;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.y < offset) {
          const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
          sum_loss1 += buf[2 * read_i];
          sum_loss2 += buf[2 * read_i + 1];
        }
        __syncthreads();
      }
      if (threadIdx.y == 0) {
        buf[2 * threadIdx.x] = sum_loss1;
        buf[2 * threadIdx.x + 1] = sum_loss2;
      }
      __syncthreads();
      if (threadIdx.y != 0) {
        sum_loss1 = buf[2 * threadIdx.x];
        sum_loss2 = buf[2 * threadIdx.x + 1];
      }
    }
    // all threads now have the two sums over l
    U fH = (U)n2;
    U term1 = (U(1) / fH) * c_invvar;
    T* k_grad_input = grad_input + i1 * n2;
    if (gamma != NULL) {
      for (int l = thrx; l < n2; l += numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss * gamma[l];
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    } else {
      for (int l = thrx; l < n2; l += numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss;
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    }
  }
}

template<typename T, typename U, typename V>
void HostApplyLayerNorm(
    V* output, U* mean, U* invvar,
    const T* input, int n1, int n2,
    double epsilon,
    const V* gamma, const V* beta)
{
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const dim3 threads(32, 4, 1);
  const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
  int nshared =
      threads.y > 1 ? threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0;
  cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(
      output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta);
}

void cuda_layer_norm(
    at::Tensor* output, at::Tensor* mean, at::Tensor* invvar,
    at::Tensor* input, int n1, int n2,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor* gamma, at::Tensor* beta, double epsilon)
{
  using namespace at;
  DISPATCH_DOUBLE_FLOAT_AND_HALF(input->scalar_type(), 0, "layer_norm_cuda_kernel",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      using output_t = at::Half;
      HostApplyLayerNorm(
          output->DATA_PTR<output_t>(),
          mean->DATA_PTR<accscalar_t>(),
          invvar->DATA_PTR<accscalar_t>(),
          input->DATA_PTR<scalar_t_0>(),
          n1, n2,
          epsilon,
          gamma != NULL ? gamma->DATA_PTR<output_t>() : NULL,
          beta != NULL ? beta->DATA_PTR<output_t>() : NULL);
  )
}

template<typename T, typename U, typename V>
void HostLayerNormGradient(
    const V* dout,
    const U* mean,
    const U* invvar,
    at::Tensor* input,
    int n1, int n2,
    const V* gamma,
    const V* beta,
    double epsilon,
    T* grad_input,
    V* grad_gamma,
    V* grad_beta)
{
  auto stream = at::cuda::getCurrentCUDAStream().stream();

  if (gamma != NULL && beta != NULL) {
    // compute grad_gamma(j) and grad_beta(j)
    const int part_size = 16;
    const dim3 threads2(32, 4, 1);
    const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1);
    const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
    const int nshared2_b = threads2.x * threads2.y * sizeof(U);
    const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
    at::Tensor part_grad_gamma = at::empty({part_size, n2}, input->options().dtype(
        input->scalar_type() == at::ScalarType::Half ? at::ScalarType::Float : input->scalar_type()));
    at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
    cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
        dout,
        input->DATA_PTR<T>(),
        n1, n2,
        mean,
        invvar,
        U(epsilon),
        part_grad_gamma.DATA_PTR<U>(),
        part_grad_beta.DATA_PTR<U>());

    const dim3 threads3(32, 8, 1);
    const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1);
    const int nshared3 = threads3.x * threads3.y * sizeof(U);
    cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
        part_grad_gamma.DATA_PTR<U>(),
        part_grad_beta.DATA_PTR<U>(),
        part_size,
        n1, n2,
        grad_gamma,
        grad_beta);
  }

  // compute grad_input
  const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
  const dim3 threads1(32, 4, 1);
  int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0;
  cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
      dout,
      input->DATA_PTR<T>(),
      n1, n2,
      mean,
      invvar,
      U(epsilon),
      gamma,
      grad_input);
}

void cuda_layer_norm_gradient(
    at::Tensor* dout, at::Tensor* mean, at::Tensor* invvar,
    at::Tensor* input, int n1, int n2,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor* gamma, at::Tensor* beta, double epsilon,
    at::Tensor* grad_input, at::Tensor* grad_gamma, at::Tensor* grad_beta)
{
  using namespace at;
  DISPATCH_FLOAT_AND_HALF(input->scalar_type(), 0, "cuComputeGradInput",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      using output_t = at::Half;
      HostLayerNormGradient(
          dout->DATA_PTR<output_t>(),
          mean->DATA_PTR<accscalar_t>(),
          invvar->DATA_PTR<accscalar_t>(),
          input,
          n1, n2,
          // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
          // if gamma Tensor is NULL on input.
          gamma != NULL ? gamma->DATA_PTR<output_t>() : NULL,
          gamma != NULL ? beta->DATA_PTR<output_t>() : NULL,
          epsilon,
          grad_input->DATA_PTR<scalar_t_0>(),
          gamma != NULL ? grad_gamma->DATA_PTR<output_t>() : NULL,
          gamma != NULL ? grad_beta->DATA_PTR<output_t>() : NULL);
  )
}
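The statistics kernels above implement Welford's online mean/variance update (cuWelfordOnlineSum) and Chan's parallel-merge formula (cuChanOnlineSum), reduced first within a warp and then across warps; cuWelfordMuSigma2 finally divides the accumulated M2 by n2 to get the biased variance, and cuApplyLayerNorm turns that into invvar = rsqrt(variance + epsilon). A plain-Python sketch of the same two update rules, useful for checking the math on the CPU (the test data below is arbitrary):

# Plain-Python sketch of the two update rules used by the kernels above.
# Welford: fold one new sample into a running (mean, M2, count).
def welford_online_sum(curr, mu, sigma2, count):
    count += 1.0
    delta = curr - mu
    mu += delta / count
    sigma2 += delta * (curr - mu)   # sigma2 accumulates the sum of squared deviations (M2)
    return mu, sigma2, count

# Chan et al.: merge two partial (mean, M2, count) accumulators, as done when
# combining lanes within a warp and then warps within a block.
def chan_online_sum(mu_b, sigma2_b, count_b, mu, sigma2, count):
    delta = mu_b - mu
    n_a, n_b = count, count_b
    n_x = n_a + n_b
    if n_x > 0:
        mu = (n_a * mu + n_b * mu_b) / n_x
        sigma2 = sigma2 + sigma2_b + delta * delta * n_a * n_b / n_x
    else:
        mu, sigma2 = 0.0, 0.0
    return mu, sigma2, n_x

# Sanity check against a direct computation over arbitrary sample data.
if __name__ == '__main__':
    data = [0.5, -1.25, 3.0, 2.0, -0.75, 1.5]
    mu1 = sigma21 = cnt1 = 0.0
    for v in data[:3]:
        mu1, sigma21, cnt1 = welford_online_sum(v, mu1, sigma21, cnt1)
    mu2 = sigma22 = cnt2 = 0.0
    for v in data[3:]:
        mu2, sigma22, cnt2 = welford_online_sum(v, mu2, sigma22, cnt2)
    mu, m2, n = chan_online_sum(mu2, sigma22, cnt2, mu1, sigma21, cnt1)
    mean = sum(data) / len(data)
    m2_direct = sum((v - mean) ** 2 for v in data)   # same sum-of-squares convention
    assert abs(mu - mean) < 1e-12 and abs(m2 - m2_direct) < 1e-9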
megatron/fused_kernels/type_shim.h (new file, mode 100644)
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* This code is copied from NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
#include <ATen/ATen.h>
#include "compat.h"
// Forward/backward compatiblity hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Byte: \
{ \
using scalar_t_##LEVEL = uint8_t; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes(
    T* x,
    T val,
    int lanes = 1,
    bool share_result = false)  // lanes is intended to be <= 32.
{
  int tid = threadIdx.x + threadIdx.y * blockDim.x;
  int blockSize = blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.

  if (blockSize >= 64) {
    x[tid] = val;
    __syncthreads();
  }

  #pragma unroll
  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
    if (tid < i)
      x[tid] = x[tid] + x[tid + i];
    __syncthreads();
  }

  T final;

  if (tid < 32) {
    if (blockSize >= 64)
      final = x[tid] + x[tid + 32];
    else
      final = val;
    // __SYNCWARP();

    #pragma unroll
    for (int i = 16; i >= lanes; i >>= 1)
      final = final + __shfl_down_sync(0xffffffff, final, i);
  }

  if (share_result) {
    if (tid < lanes)
      x[tid] = final;  // EpilogueOp
    // Make sure the smem result is visible to all warps.
    __syncthreads();
  }

  return final;
}

template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes_max_op(
    T* x,
    T val,
    int lanes = 1,
    bool share_result = false)  // lanes is intended to be <= 32.
{
  int tid = threadIdx.x + threadIdx.y * blockDim.x;
  int blockSize = blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.

  if (blockSize >= 64) {
    x[tid] = val;
    __syncthreads();
  }

  #pragma unroll
  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
    if (tid < i)
      x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
    __syncthreads();
  }

  T final;

  if (tid < 32) {
    if (blockSize >= 64)
      final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
    else
      final = val;
    // __SYNCWARP();

    #pragma unroll
    for (int i = 16; i >= lanes; i >>= 1)
      final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
  }

  if (share_result) {
    if (tid < lanes)
      x[tid] = final;  // EpilogueOp
    // Make sure the smem result is visible to all warps.
    __syncthreads();
  }

  return final;
}
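The DISPATCH_* macros above turn a runtime at::ScalarType into a compile-time scalar_t_##LEVEL alias, and the kernels pair that type with an fp32 accumulation type via at::acc_type<scalar_t, true>. As a loose Python analogy of that pairing, with the mapping table spelled out here as an assumption mirroring what acc_type selects on CUDA for these three types:

# Loose analogy only: half accumulates in float, float/double accumulate in
# themselves, which is what at::acc_type<T, /*is_cuda=*/true> selects.
import torch

ACC_TYPE = {
    torch.half: torch.float32,
    torch.float32: torch.float32,
    torch.float64: torch.float64,
}

def dispatch_sketch(tensor):
    # The C++ macro would select `using scalar_t_0 = ...;` and run the kernel
    # body with that type; here we just report the (input, accumulation) pair.
    return tensor.dtype, ACC_TYPE[tensor.dtype]

print(dispatch_sketch(torch.zeros(4, dtype=torch.half)))   # (torch.float16, torch.float32)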
megatron/model/fused_layer_norm.py (new file, mode 100755)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This code is copied fron NVIDIA apex:
https://github.com/NVIDIA/apex
with minor changes. """
import math
import torch
import numbers
from torch.nn.parameter import Parameter
from torch.nn import init
from torch.nn import functional as F
import importlib

global fused_layer_norm_cuda
fused_layer_norm_cuda = None
global fused_mix_prec_layer_norm_cuda
fused_mix_prec_layer_norm_cuda = None


class FusedLayerNormAffineFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, weight, bias, normalized_shape, eps):
        global fused_mix_prec_layer_norm_cuda
        if fused_mix_prec_layer_norm_cuda is None:
            fused_mix_prec_layer_norm_cuda = importlib.import_module(
                "fused_mix_prec_layer_norm_cuda")

        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        bias_ = bias.contiguous()
        output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
            input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight_, bias_, mean, invvar = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        grad_input, grad_weight, grad_bias \
            = fused_mix_prec_layer_norm_cuda.backward_affine(
                grad_output.contiguous(), mean, invvar,
                input_, ctx.normalized_shape,
                weight_, bias_, ctx.eps)

        return grad_input, grad_weight, grad_bias, None, None


class FusedLayerNormFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, normalized_shape, eps):
        global fused_layer_norm_cuda
        if fused_layer_norm_cuda is None:
            fused_layer_norm_cuda = importlib.import_module(
                "fused_layer_norm_cuda")
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        input_ = input.contiguous()
        output, mean, invvar = fused_layer_norm_cuda.forward(
            input_, ctx.normalized_shape, ctx.eps)
        ctx.save_for_backward(input_, mean, invvar)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, mean, invvar = ctx.saved_tensors
        grad_input = None
        grad_input = fused_layer_norm_cuda.backward(
            grad_output.contiguous(), mean, invvar,
            input_, ctx.normalized_shape, ctx.eps)

        return grad_input, None, None


def fused_layer_norm_affine(input, normalized_shape, weight, bias, eps=1e-6):
    return FusedLayerNormAffineFunction.apply(
        input, weight, bias, normalized_shape, eps)


def fused_layer_norm(input, normalized_shape, eps=1e-6):
    return FusedLayerNormFunction.apply(input, normalized_shape, eps)


class MixedFusedLayerNorm(torch.nn.Module):

    r"""Applies Layer Normalization over a mini-batch of inputs as described in
the paper `Layer Normalization`_ .
Currently only runs on cuda() tensors.
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The mean and standard-deviation are calculated separately over the last
certain number dimensions which have to be of the shape specified by
:attr:`normalized_shape`.
:math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
:attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
.. note::
Unlike Batch Normalization and Instance Normalization, which applies
scalar scale and bias for each entire channel/plane with the
:attr:`affine` option, Layer Normalization applies per-element scale and
bias with :attr:`elementwise_affine`.
This layer uses statistics computed from input data in both training and
evaluation modes.
Args:
normalized_shape (int or list or torch.Size): input shape from an expected input
of size
.. math::
[* \times \text{normalized}\_\text{shape}[0] \times \text{normalized}\_\text{shape}[1]
\times \ldots \times \text{normalized}\_\text{shape}[-1]]
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps: a value added to the denominator for numerical stability. Default: 1e-5
elementwise_affine: a boolean value that when set to ``True``, this module
has learnable per-element affine parameters initialized to ones (for weights)
and zeros (for biases). Default: ``True``.
Shape:
- Input: :math:`(N, *)`
- Output: :math:`(N, *)` (same shape as input)
Examples::
>>> input = torch.randn(20, 5, 10, 10)
>>> # With Learnable Parameters
>>> m = apex.normalization.FusedLayerNorm(input.size()[1:])
>>> # Without Learnable Parameters
>>> m = apex.normalization.FusedLayerNorm(input.size()[1:], elementwise_affine=False)
>>> # Normalize over last two dimensions
>>> m = apex.normalization.FusedLayerNorm([10, 10])
>>> # Normalize over last dimension of size 10
>>> m = apex.normalization.FusedLayerNorm(10)
>>> # Activating the module
>>> output = m(input)
.. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
"""
    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super(MixedFusedLayerNorm, self).__init__()

        global fused_layer_norm_cuda
        fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
        global fused_mix_prec_layer_norm_cuda
        fused_mix_prec_layer_norm_cuda = importlib.import_module(
            "fused_mix_prec_layer_norm_cuda")

        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = Parameter(torch.Tensor(*normalized_shape))
            self.bias = Parameter(torch.Tensor(*normalized_shape))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        if self.elementwise_affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def forward(self, input):
        if not input.is_cuda:
            return F.layer_norm(
                input, self.normalized_shape, self.weight, self.bias, self.eps)
        if self.elementwise_affine:
            return FusedLayerNormAffineFunction.apply(
                input, self.weight, self.bias, self.normalized_shape, self.eps)
        else:
            return FusedLayerNormFunction.apply(
                input, self.normalized_shape, self.eps)

    def extra_repr(self):
        return '{normalized_shape}, eps={eps}, ' \
            'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
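A short usage sketch of MixedFusedLayerNorm as defined above. It assumes both extensions it imports are available (apex's fused_layer_norm_cuda and the new fused_mix_prec_layer_norm_cuda built by load_fused_mix_prec_layer_norm_kernel()); the fp16-parameter / fp32-activation combination is inferred from the commit title and the Half output dtype chosen in layer_norm_affine, so treat the whole snippet as illustrative:

# Illustrative usage; shapes and dtype choices are assumptions.
import torch
from megatron.model.fused_layer_norm import MixedFusedLayerNorm

hidden_size = 1024
ln = MixedFusedLayerNorm(hidden_size, eps=1e-5)

y_cpu = ln(torch.randn(16, hidden_size))       # non-CUDA input falls back to F.layer_norm

if torch.cuda.is_available():
    ln = ln.cuda().half()                      # keep gamma/beta in fp16
    x = torch.randn(16, hidden_size, device='cuda',
                    dtype=torch.float32, requires_grad=True)   # fp32 residual stream
    y = ln(x)                                  # fused kernel: fp32 input, fp16 output
    y.sum().backward()                         # exercises backward_affine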