Commit c84f0752 authored by Deepak Narayanan's avatar Deepak Narayanan
Browse files

Merge branch 'fp32_residual_conn' into 'main'

Option to run residual connections in fp32

See merge request ADLR/megatron-lm!195
parents 9b174da8 83671bbf
...@@ -183,6 +183,9 @@ def parse_args(extra_args_provider=None, defaults={}, ...@@ -183,6 +183,9 @@ def parse_args(extra_args_provider=None, defaults={},
# Mixed precision checks. # Mixed precision checks.
if args.fp16_lm_cross_entropy: if args.fp16_lm_cross_entropy:
assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.' assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
if args.fp32_residual_connection:
assert args.fp16, \
'residual connection in fp32 only supported when using fp16.'
# Activation checkpointing. # Activation checkpointing.
if args.distribute_checkpointed_activations: if args.distribute_checkpointed_activations:
assert args.checkpoint_activations, \ assert args.checkpoint_activations, \
...@@ -197,6 +200,10 @@ def parse_args(extra_args_provider=None, defaults={}, ...@@ -197,6 +200,10 @@ def parse_args(extra_args_provider=None, defaults={},
if args.scaled_masked_softmax_fusion: if args.scaled_masked_softmax_fusion:
fused_kernels.load_scaled_masked_softmax_fusion_kernel() fused_kernels.load_scaled_masked_softmax_fusion_kernel()
# Load mixed precision fused layer norm.
if args.fp32_residual_connection:
fused_kernels.load_fused_mix_prec_layer_norm_kernel()
_print_args(args) _print_args(args)
return args return args
...@@ -435,6 +442,8 @@ def _add_mixed_precision_args(parser): ...@@ -435,6 +442,8 @@ def _add_mixed_precision_args(parser):
group.add_argument('--fp16', action='store_true', group.add_argument('--fp16', action='store_true',
help='Run model in fp16 mode.') help='Run model in fp16 mode.')
group.add_argument('--fp32-residual-connection', action='store_true',
help='Move residual connections to fp32.')
group.add_argument('--apply-query-key-layer-scaling', action='store_true', group.add_argument('--apply-query-key-layer-scaling', action='store_true',
help='Scale Q * K^T by 1 / layer-number. If this flag ' help='Scale Q * K^T by 1 / layer-number. If this flag '
'is set, then it will automatically set ' 'is set, then it will automatically set '
......
...@@ -98,3 +98,29 @@ def load_scaled_masked_softmax_fusion_kernel(): ...@@ -98,3 +98,29 @@ def load_scaled_masked_softmax_fusion_kernel():
'--expt-relaxed-constexpr', '--expt-relaxed-constexpr',
'--expt-extended-lambda', '--expt-extended-lambda',
'--use_fast_math'] + cc_flag) '--use_fast_math'] + cc_flag)
def load_fused_mix_prec_layer_norm_kernel():
# Check, if CUDA11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')
srcpath = pathlib.Path(__file__).parent.absolute()
buildpath = srcpath / 'build'
create_build_dir(buildpath)
fused_mix_prec_layer_norm_cuda = cpp_extension.load(
name='fused_mix_prec_layer_norm_cuda',
sources=[srcpath / 'layer_norm_cuda.cpp',
srcpath / 'layer_norm_cuda_kernel.cu'],
build_directory=buildpath,
extra_cflags=['-O3'],
extra_cuda_cflags=['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-maxrregcount=50',
'--use_fast_math'] + cc_flag)
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*This code is copied fron NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*This code is copied fron NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
#include <torch/extension.h>
#include <vector>
#include <cassert>
#include "compat.h"
namespace {
void compute_n1_n2(
at::Tensor input,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
int& n1,
int& n2)
{
int idiff = input.ndimension() - normalized_shape.size();
n2 = 1;
for (int i = 0; i < (int)normalized_shape.size(); ++i) {
assert( input.sizes()[i+idiff] == normalized_shape[i] );
n2 *= normalized_shape[i];
}
n1 = 1;
for (int i = 0; i < idiff; ++i) {
n1 *= input.sizes()[i];
}
}
void check_args(
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor gamma,
at::Tensor beta
)
{
TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
}
void check_args(
at::Tensor input,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
int& n1,
int& n2
)
{
int64_t normalized_ndim = normalized_shape.size();
if (normalized_ndim < 1) {
std::stringstream ss;
ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
<< "containing at least one element, but got normalized_shape="
<< normalized_shape;
throw std::runtime_error(ss.str());
}
auto input_shape = input.sizes();
auto input_ndim = input.dim();
if (input_ndim < normalized_ndim ||
!input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {
std::stringstream ss;
ss << "Given normalized_shape=" << normalized_shape
<< ", expected input with shape [*";
for (auto size : normalized_shape) {
ss << ", " << size;
}
ss << "], but got input of size" << input_shape;
throw std::runtime_error(ss.str());
}
compute_n1_n2(input,normalized_shape,n1,n2);
}
void check_args(
at::Tensor input,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor gamma,
at::Tensor beta,
int& n1,
int& n2
)
{
check_args(input,normalized_shape,n1,n2);
check_args(normalized_shape,gamma,beta);
}
}
void cuda_layer_norm(
at::Tensor* output,
at::Tensor* mean,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor* gamma,
at::Tensor* beta,
double epsilon);
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<at::Tensor> layer_norm(
at::Tensor input,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
double epsilon) {
CHECK_INPUT(input);
int n1,n2;
check_args(input,normalized_shape,n1,n2);
at::Tensor output = at::empty_like(input);
at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
at::Tensor invvar = at::empty_like(mean);
cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
normalized_shape,NULL,NULL,epsilon);
return {output, mean, invvar};
}
std::vector<at::Tensor> layer_norm_affine(
at::Tensor input,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor gamma,
at::Tensor beta,
double epsilon) {
CHECK_INPUT(input);
CHECK_INPUT(gamma);
CHECK_INPUT(beta);
int n1,n2;
check_args(input,normalized_shape,gamma,beta,n1,n2);
at::Tensor output = at::empty_like(input, input.options().dtype(at::ScalarType::Half));
at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
at::Tensor invvar = at::empty_like(mean);
cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
normalized_shape,&gamma,&beta,epsilon);
return {output, mean, invvar};
}
void cuda_layer_norm_gradient(
at::Tensor* dout,
at::Tensor* mean,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor* gamma,
at::Tensor* beta,
double epsilon,
at::Tensor* grad_input,
at::Tensor* grad_gamma,
at::Tensor* grad_beta
);
at::Tensor layer_norm_gradient(
at::Tensor dout,
at::Tensor mean,
at::Tensor invvar,
at::Tensor input,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
double epsilon) {
CHECK_INPUT(dout);
CHECK_INPUT(mean);
CHECK_INPUT(invvar);
CHECK_INPUT(input);
int n1,n2;
check_args(input,normalized_shape,n1,n2);
at::Tensor grad_input = at::empty_like(input);
cuda_layer_norm_gradient(&dout,&mean,&invvar,&input,n1,n2,
normalized_shape,NULL,NULL,epsilon,
&grad_input,NULL,NULL);
return grad_input;
}
std::vector<at::Tensor> layer_norm_gradient_affine(
at::Tensor dout,
at::Tensor mean,
at::Tensor invvar,
at::Tensor input,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor gamma,
at::Tensor beta,
double epsilon) {
CHECK_INPUT(dout);
CHECK_INPUT(mean);
CHECK_INPUT(invvar);
CHECK_INPUT(input);
CHECK_INPUT(gamma);
CHECK_INPUT(beta);
int n1,n2;
check_args(input,normalized_shape,gamma,beta,n1,n2);
at::Tensor grad_input = at::empty_like(input);
at::Tensor grad_gamma = at::empty_like(gamma);
at::Tensor grad_beta = at::empty_like(beta);
cuda_layer_norm_gradient(&dout,&mean,&invvar,&input,n1,n2,
normalized_shape,&gamma,&beta,epsilon,
&grad_input,&grad_gamma,&grad_beta);
return {grad_input, grad_gamma, grad_beta};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
m.def("forward", &layer_norm, "LayerNorm forward (CUDA)");
m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)");
m.def("backward", &layer_norm_gradient, "LayerNorm backward (CUDA)");
}
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*This code is copied fron NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
#include <THC/THCDeviceUtils.cuh>
#include <cuda.h>
#include <cuda_runtime.h>
#include "type_shim.h"
template<typename U> __device__
void cuWelfordOnlineSum(
const U curr,
U& mu,
U& sigma2,
U& count)
{
count = count + U(1);
U delta = curr - mu;
U lmean = mu + delta / count;
mu = lmean;
U delta2 = curr - lmean;
sigma2 = sigma2 + delta * delta2;
}
template<typename U> __device__
void cuChanOnlineSum(
const U muB,
const U sigma2B,
const U countB,
U& mu,
U& sigma2,
U& count)
{
U delta = muB - mu;
U nA = count;
U nB = countB;
count = count + countB;
U nX = count;
if (nX > U(0)) {
nA = nA / nX;
nB = nB / nX;
mu = nA*mu + nB*muB;
sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX;
} else {
mu = U(0);
sigma2 = U(0);
}
}
template<typename T, typename U> __device__
void cuWelfordMuSigma2(
const T* __restrict__ vals,
const int n1,
const int n2,
const int i1,
U& mu,
U& sigma2,
U* buf)
{
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensor is contiguous
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
//
// compute variance and mean over n2
U count = U(0);
mu= U(0);
sigma2 = U(0);
if (i1 < n1) {
// one warp normalizes one n1 index,
// synchronization is implicit
// initialize with standard Welford algorithm
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
const T* lvals = vals + i1*n2;
int l = 4*thrx;
for (; l+3 < n2; l+=4*numx) {
for (int k = 0; k < 4; ++k) {
U curr = static_cast<U>(lvals[l+k]);
cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
}
}
for (; l < n2; ++l) {
U curr = static_cast<U>(lvals[l]);
cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
}
// intra-warp reductions
for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x+(1<<l))&31;
U muB = WARP_SHFL(mu, srcLaneB);
U countB = WARP_SHFL(count, srcLaneB);
U sigma2B = WARP_SHFL(sigma2, srcLaneB);
cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
}
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
if (blockDim.y > 1) {
U* ubuf = (U*)buf;
U* ibuf = (U*)(ubuf + blockDim.y);
for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
const int wrt_y = threadIdx.y - offset;
ubuf[2*wrt_y] = mu;
ubuf[2*wrt_y+1] = sigma2;
ibuf[wrt_y] = count;
}
__syncthreads();
// lower half merges
if (threadIdx.x == 0 && threadIdx.y < offset) {
U muB = ubuf[2*threadIdx.y];
U sigma2B = ubuf[2*threadIdx.y+1];
U countB = ibuf[threadIdx.y];
cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
}
__syncthreads();
}
// threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
if (threadIdx.x == 0 && threadIdx.y == 0) {
ubuf[0] = mu;
ubuf[1] = sigma2;
}
__syncthreads();
mu = ubuf[0];
sigma2 = ubuf[1]/U(n2);
// don't care about final value of count, we know count == n2
} else {
mu = WARP_SHFL(mu, 0);
sigma2 = WARP_SHFL(sigma2/U(n2), 0);
}
}
}
template<> __device__
void cuWelfordMuSigma2(
const at::Half* __restrict__ vals,
const int n1,
const int n2,
const int i1,
float& mu,
float& sigma2,
float* buf)
{
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensor is contiguous
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
//
// compute variance and mean over n2
float count = 0.0f;
mu= float(0);
sigma2 = float(0);
if (i1 < n1) {
// one warp normalizes one n1 index,
// synchronization is implicit
// initialize with standard Welford algorithm
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
const at::Half* lvals = vals + i1*n2;
int l = 8*thrx;
if ((((size_t)lvals)&3) != 0) {
// 16 bit alignment
// first thread consumes first point
if (thrx == 0) {
float curr = static_cast<float>(lvals[0]);
cuWelfordOnlineSum(curr,mu,sigma2,count);
}
++l;
}
// at this point, lvals[l] are 32 bit aligned for all threads.
for (; l+7 < n2; l+=8*numx) {
for (int k = 0; k < 8; k+=2) {
float2 curr = __half22float2(*((__half2*)(lvals+l+k)));
cuWelfordOnlineSum(curr.x,mu,sigma2,count);
cuWelfordOnlineSum(curr.y,mu,sigma2,count);
}
}
for (; l < n2; ++l) {
float curr = static_cast<float>(lvals[l]);
cuWelfordOnlineSum(curr,mu,sigma2,count);
}
// intra-warp reductions
for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x+(1<<l))&31;
float muB = WARP_SHFL(mu, srcLaneB);
float countB = WARP_SHFL(count, srcLaneB);
float sigma2B = WARP_SHFL(sigma2, srcLaneB);
cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
}
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
if (blockDim.y > 1) {
float* ubuf = (float*)buf;
float* ibuf = (float*)(ubuf + blockDim.y);
for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
const int wrt_y = threadIdx.y - offset;
ubuf[2*wrt_y] = mu;
ubuf[2*wrt_y+1] = sigma2;
ibuf[wrt_y] = count;
}
__syncthreads();
// lower half merges
if (threadIdx.x == 0 && threadIdx.y < offset) {
float muB = ubuf[2*threadIdx.y];
float sigma2B = ubuf[2*threadIdx.y+1];
float countB = ibuf[threadIdx.y];
cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
}
__syncthreads();
}
// threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
if (threadIdx.x == 0 && threadIdx.y == 0) {
ubuf[0] = mu;
ubuf[1] = sigma2;
}
__syncthreads();
mu = ubuf[0];
sigma2 = ubuf[1]/float(n2);
// don't care about final value of count, we know count == n2
} else {
mu = WARP_SHFL(mu, 0);
sigma2 = WARP_SHFL(sigma2/float(n2), 0);
}
}
}
template<typename U> U rsqrt(U v) {
return U(1) / sqrt(v);
}
template<> float rsqrt(float v) {
return rsqrtf(v);
}
template<> double rsqrt(double v) {
return rsqrt(v);
}
namespace {
// This is the un-specialized struct. Note that we prevent instantiation of this
// struct by putting an undefined symbol in the function body so it won't compile.
// template <typename T>
// struct SharedMemory
// {
// // Ensure that we won't compile any un-specialized types
// __device__ T *getPointer()
// {
// extern __device__ void error(void);
// error();
// return NULL;
// }
// };
// https://github.com/NVIDIA/apex/issues/246
template <typename T>
struct SharedMemory;
template <>
struct SharedMemory <float>
{
__device__ float *getPointer()
{
extern __shared__ float s_float[];
return s_float;
}
};
template <>
struct SharedMemory <double>
{
__device__ double *getPointer()
{
extern __shared__ double s_double[];
return s_double;
}
};
}
template<typename T, typename U, typename V> __global__
void cuApplyLayerNorm(
V* __restrict__ output_vals,
U* __restrict__ mean,
U* __restrict__ invvar,
const T* __restrict__ vals,
const int n1,
const int n2,
const U epsilon,
const V* __restrict__ gamma,
const V* __restrict__ beta
)
{
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensors are contiguous
//
for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
SharedMemory<U> shared;
U* buf = shared.getPointer();
U mu,sigma2;
cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf);
const T* lvals = vals + i1*n2;
V* ovals = output_vals + i1*n2;
U c_invvar = rsqrt(sigma2 + epsilon);
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
if (gamma != NULL && beta != NULL) {
for (int i = thrx; i < n2; i+=numx) {
U curr = static_cast<U>(lvals[i]);
ovals[i] = gamma[i] * static_cast<V>(c_invvar * (curr - mu)) + beta[i];
}
} else {
for (int i = thrx; i < n2; i+=numx) {
U curr = static_cast<U>(lvals[i]);
ovals[i] = static_cast<V>(c_invvar * (curr - mu));
}
}
if (threadIdx.x == 0 && threadIdx.y == 0) {
mean[i1] = mu;
invvar[i1] = c_invvar;
}
}
}
template<typename T, typename U, typename V> __device__
void cuLoadWriteStridedInputs(
const int i1_block,
const int thr_load_row_off,
const int thr_load_col_off,
const int i2_off,
const int row_stride,
U* warp_buf1,
U* warp_buf2,
const T* input,
const V* dout,
const int i1_end,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar
)
{
int i1 = i1_block+thr_load_row_off;
if (i1 < i1_end) {
U curr_mean = mean[i1];
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k;
int load_idx = i1*n2+i2;
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
if (i2<n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] = curr_dout;
warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
} else {
warp_buf1[write_idx] = U(0);
warp_buf2[write_idx] = U(0);
}
}
} else {
for (int k = 0; k < blockDim.y; ++k) {
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
warp_buf1[write_idx] = U(0);
warp_buf2[write_idx] = U(0);
}
}
}
template<typename T, typename U, typename V> __device__
void cuLoadAddStridedInputs(
const int i1_block,
const int thr_load_row_off,
const int thr_load_col_off,
const int i2_off,
const int row_stride,
U* warp_buf1,
U* warp_buf2,
const T* input,
const V* dout,
const int i1_end,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar
)
{
int i1 = i1_block+thr_load_row_off;
if (i1 < i1_end) {
U curr_mean = mean[i1];
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k;
int load_idx = i1*n2+i2;
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
if (i2<n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] += curr_dout;
warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
}
}
}
}
template<typename T, typename U, typename V> __global__
void cuComputePartGradGammaBeta(
const V* __restrict__ dout,
const T* __restrict__ input,
const int n1,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar,
U epsilon,
U* part_grad_gamma,
U* part_grad_beta)
{
const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y);
const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y;
const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y;
const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
const int row_stride = blockDim.x+1;
const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1);
const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y;
const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
SharedMemory<U> shared;
U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements
U* warp_buf1 = (U*)buf;
U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
// compute partial sums from strided inputs
// do this to increase number of loads in flight
cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) {
cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
}
__syncthreads();
// inter-warp reductions
// sum within each warp
U acc1 = U(0);
U acc2 = U(0);
for (int k = 0; k < blockDim.y; ++k) {
int row1 = threadIdx.y + k*blockDim.y;
int idx1 = row1*row_stride + threadIdx.x;
acc1 += warp_buf1[idx1];
acc2 += warp_buf2[idx1];
}
warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1;
warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2;
__syncthreads();
// sum all warps
for (int offset = blockDim.y/2; offset > 1; offset /= 2) {
if (threadIdx.y < offset) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + offset;
int idx1 = row1*row_stride + threadIdx.x;
int idx2 = row2*row_stride + threadIdx.x;
warp_buf1[idx1] += warp_buf1[idx2];
warp_buf2[idx1] += warp_buf2[idx2];
}
__syncthreads();
}
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (threadIdx.y == 0 && i2 < n2) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + 1;
int idx1 = row1*row_stride + threadIdx.x;
int idx2 = row2*row_stride + threadIdx.x;
part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2];
part_grad_gamma[blockIdx.y*n2+i2] = warp_buf2[idx1] + warp_buf2[idx2];
}
}
template<typename U, typename V> __global__
void cuComputeGradGammaBeta(
const U* part_grad_gamma,
const U* part_grad_beta,
const int part_size,
const int n1,
const int n2,
V* grad_gamma,
V* grad_beta)
{
// sum partial gradients for gamma and beta
SharedMemory<U> shared;
U* buf = shared.getPointer();
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (i2 < n2) {
// each warp does sequential reductions until reduced part_size is num_warps
int num_warp_reductions = part_size / blockDim.y;
U sum_gamma = U(0);
U sum_beta = U(0);
const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) {
sum_gamma += part_grad_gamma_ptr[warp_offset*n2];
sum_beta += part_grad_beta_ptr[warp_offset*n2];
}
// inter-warp reductions
const int nbsize3 = blockDim.x * blockDim.y / 2;
for (int offset = blockDim.y/2; offset >= 1; offset /= 2) {
// top half write to shared memory
if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[write_idx] = sum_gamma;
buf[write_idx+nbsize3] = sum_beta;
}
__syncthreads();
// bottom half sums
if (threadIdx.y < offset) {
const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
sum_gamma += buf[read_idx];
sum_beta += buf[read_idx+nbsize3];
}
__syncthreads();
}
// write out fully summed gradients
if (threadIdx.y == 0) {
grad_gamma[i2] = sum_gamma;
grad_beta[i2] = sum_beta;
}
}
}
template<typename T, typename U, typename V> __global__
void cuComputeGradInput(
const V* __restrict__ dout,
const T* __restrict__ input,
const int n1,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar,
U epsilon,
const V* gamma,
T* grad_input)
{
for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
U sum_loss1 = U(0);
U sum_loss2 = U(0);
const U c_mean = mean[i1];
const U c_invvar = invvar[i1];
const T* k_input = input + i1*n2;
const V* k_dout = dout + i1*n2;
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
if (gamma != NULL) {
int l = 4*thrx;
for (; l+3 < n2; l+=4*numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l+k]);
const U c_loss = static_cast<U>(k_dout[l+k]);
sum_loss1 += c_loss * gamma[l+k];
sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar;
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss * gamma[l];
sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
}
} else {
int l = 4*thrx;
for (; l+3 < n2; l+=4*numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l+k]);
const U c_loss = static_cast<U>(k_dout[l+k]);
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
}
}
// intra-warp reductions
for (int mask = blockDim.x/2; mask > 0; mask /= 2) {
sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
}
// inter-warp reductions
if (blockDim.y > 1) {
SharedMemory<U> shared;
U* buf = shared.getPointer();
for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[2*wrt_i] = sum_loss1;
buf[2*wrt_i+1] = sum_loss2;
}
__syncthreads();
// lower half merges
if (threadIdx.y < offset) {
const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
sum_loss1 += buf[2*read_i];
sum_loss2 += buf[2*read_i+1];
}
__syncthreads();
}
if (threadIdx.y == 0) {
buf[2*threadIdx.x] = sum_loss1;
buf[2*threadIdx.x+1] = sum_loss2;
}
__syncthreads();
if (threadIdx.y !=0) {
sum_loss1 = buf[2*threadIdx.x];
sum_loss2 = buf[2*threadIdx.x+1];
}
}
// all threads now have the two sums over l
U fH = (U)n2;
U term1 = (U(1) / fH) * c_invvar;
T* k_grad_input = grad_input + i1*n2;
if (gamma != NULL) {
for (int l = thrx; l < n2; l+=numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss * gamma[l];
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input);
}
} else {
for (int l = thrx; l < n2; l+=numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss;
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input);
}
}
}
}
template<typename T, typename U, typename V>
void HostApplyLayerNorm(
V* output,
U* mean,
U* invvar,
const T* input,
int n1,
int n2,
double epsilon,
const V* gamma,
const V* beta
)
{
auto stream = at::cuda::getCurrentCUDAStream().stream();
const dim3 threads(32,4,1);
const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
int nshared =
threads.y > 1 ?
threads.y*sizeof(U)+(threads.y/2)*sizeof(U) :
0;
cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(
output,
mean,
invvar,
input,
n1,n2,
U(epsilon),
gamma,beta);
}
void cuda_layer_norm(
at::Tensor* output,
at::Tensor* mean,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor* gamma,
at::Tensor* beta,
double epsilon)
{
using namespace at;
DISPATCH_DOUBLE_FLOAT_AND_HALF(input->scalar_type(), 0, "layer_norm_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
using output_t = at::Half;
HostApplyLayerNorm(
output->DATA_PTR<output_t>(),
mean->DATA_PTR<accscalar_t>(),
invvar->DATA_PTR<accscalar_t>(),
input->DATA_PTR<scalar_t_0>(),
n1,n2,
epsilon,
gamma != NULL ? gamma->DATA_PTR<output_t>() : NULL,
beta != NULL ? beta->DATA_PTR<output_t>() : NULL);
)
}
template<typename T, typename U, typename V>
void HostLayerNormGradient(
const V* dout,
const U* mean,
const U* invvar,
at::Tensor* input,
int n1,
int n2,
const V* gamma,
const V* beta,
double epsilon,
T* grad_input,
V* grad_gamma,
V* grad_beta
)
{
auto stream = at::cuda::getCurrentCUDAStream().stream();
if (gamma != NULL && beta != NULL) {
// compute grad_gamma(j) and grad_beta(j)
const int part_size = 16;
const dim3 threads2(32,4,1);
const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
const int nshared2_b = threads2.x * threads2.y * sizeof(U);
const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(input->scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input->scalar_type()));
at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
dout,
input->DATA_PTR<T>(),
n1,n2,
mean,
invvar,
U(epsilon),
part_grad_gamma.DATA_PTR<U>(),
part_grad_beta.DATA_PTR<U>());
const dim3 threads3(32,8,1);
const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1);
const int nshared3 = threads3.x * threads3.y * sizeof(U);
cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
part_grad_gamma.DATA_PTR<U>(),
part_grad_beta.DATA_PTR<U>(),
part_size,
n1,n2,
grad_gamma,
grad_beta);
}
// compute grad_input
const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
const dim3 threads1(32,4,1);
int nshared =
threads1.y > 1 ?
threads1.y*threads1.x*sizeof(U) :
0;
cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
dout,
input->DATA_PTR<T>(),
n1,n2,
mean,
invvar,
U(epsilon),
gamma,
grad_input);
}
void cuda_layer_norm_gradient(
at::Tensor* dout,
at::Tensor* mean,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor* gamma,
at::Tensor* beta,
double epsilon,
at::Tensor* grad_input,
at::Tensor* grad_gamma,
at::Tensor* grad_beta)
{
using namespace at;
DISPATCH_FLOAT_AND_HALF(input->scalar_type(), 0, "cuComputeGradInput",
using accscalar_t = at::acc_type<scalar_t_0, true>;
using output_t = at::Half;
HostLayerNormGradient(
dout->DATA_PTR<output_t>(),
mean->DATA_PTR<accscalar_t>(),
invvar->DATA_PTR<accscalar_t>(),
input,
n1,n2,
// TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
// if gamma Tensor is NULL on input.
gamma != NULL ? gamma->DATA_PTR<output_t>() : NULL,
gamma != NULL ? beta->DATA_PTR<output_t>() : NULL,
epsilon,
grad_input->DATA_PTR<scalar_t_0>(),
gamma != NULL ? grad_gamma->DATA_PTR<output_t>() : NULL,
gamma != NULL ? grad_beta->DATA_PTR<output_t>() : NULL);
)
}
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*This code is copied fron NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
#include <ATen/ATen.h>
#include "compat.h"
// Forward/backward compatiblity hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Byte: \
{ \
using scalar_t_##LEVEL = uint8_t; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes
(T *x,
T val,
int lanes=1,
bool share_result=false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y*blockDim.x;
int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32.
if(blockSize >= 64)
{
x[tid] = val;
__syncthreads();
}
#pragma unroll
for(int i = (blockSize >> 1); i >= 64; i >>= 1)
{
if(tid < i)
x[tid] = x[tid] + x[tid+i];
__syncthreads();
}
T final;
if(tid < 32)
{
if(blockSize >= 64)
final = x[tid] + x[tid+32];
else
final = val;
// __SYNCWARP();
#pragma unroll
for(int i = 16; i >= lanes; i >>= 1)
final = final + __shfl_down_sync(0xffffffff, final, i);
}
if(share_result)
{
if(tid < lanes)
x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes_max_op
(T *x,
T val,
int lanes=1,
bool share_result=false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y*blockDim.x;
int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32.
if(blockSize >= 64)
{
x[tid] = val;
__syncthreads();
}
#pragma unroll
for(int i = (blockSize >> 1); i >= 64; i >>= 1)
{
if(tid < i)
x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid+i]));
__syncthreads();
}
T final;
if(tid < 32)
{
if(blockSize >= 64)
final = fmaxf(fabsf(x[tid]), fabsf(x[tid+32]));
else
final = val;
// __SYNCWARP();
#pragma unroll
for(int i = 16; i >= lanes; i >>= 1)
final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
}
if(share_result)
{
if(tid < lanes)
x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
...@@ -13,9 +13,27 @@ ...@@ -13,9 +13,27 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
_LAYER_NORM = None
def import_layernorm(fp32_residual_connection):
global _LAYER_NORM
if not _LAYER_NORM:
if fp32_residual_connection:
from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
else:
from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
_LAYER_NORM = LayerNorm
return _LAYER_NORM
from .distributed import * from .distributed import *
from .bert_model import BertModel, BertModelFirstStage, BertModelIntermediateStage, BertModelLastStage from .bert_model import BertModel, BertModelFirstStage, BertModelIntermediateStage, BertModelLastStage
from .realm_model import ICTBertModel from .realm_model import ICTBertModel
from .gpt2_model import GPT2Model, GPT2ModelFirstStage, GPT2ModelIntermediateStage, GPT2ModelLastStage from .gpt2_model import GPT2Model, GPT2ModelFirstStage, GPT2ModelIntermediateStage, GPT2ModelLastStage
from .utils import get_params_for_weight_decay_optimization from .utils import get_params_for_weight_decay_optimization
from .language_model import get_language_model from .language_model import get_language_model
...@@ -21,7 +21,7 @@ from megatron import get_args ...@@ -21,7 +21,7 @@ from megatron import get_args
from megatron import mpu from megatron import mpu
from megatron.model.language_model import parallel_lm_logits from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model from megatron.model.language_model import get_language_model
from megatron.model.transformer import LayerNorm from megatron.model import import_layernorm
from megatron.model.utils import openai_gelu, erf_gelu from megatron.model.utils import openai_gelu, erf_gelu
from megatron.model.utils import get_linear_layer from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal from megatron.model.utils import init_method_normal
...@@ -83,6 +83,7 @@ class BertLMHead(MegatronModule): ...@@ -83,6 +83,7 @@ class BertLMHead(MegatronModule):
self.parallel_output = parallel_output self.parallel_output = parallel_output
self.dense = get_linear_layer(hidden_size, hidden_size, init_method) self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
LayerNorm = import_layernorm(args.fp32_residual_connection)
self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
self.gelu = torch.nn.functional.gelu self.gelu = torch.nn.functional.gelu
if args.openai_gelu: if args.openai_gelu:
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This code is copied fron NVIDIA apex:
https://github.com/NVIDIA/apex
with minor changes. """
import math
import torch
import numbers
from torch.nn.parameter import Parameter
from torch.nn import init
from torch.nn import functional as F
import importlib
global fused_layer_norm_cuda
fused_layer_norm_cuda = None
global fused_mix_prec_layer_norm_cuda
fused_mix_prec_layer_norm_cuda = None
class FusedLayerNormAffineFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, bias, normalized_shape, eps):
global fused_mix_prec_layer_norm_cuda
if fused_mix_prec_layer_norm_cuda is None:
fused_mix_prec_layer_norm_cuda = importlib.import_module("fused_mix_prec_layer_norm_cuda")
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
weight_ = weight.contiguous()
bias_ = bias.contiguous()
output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
return output
@staticmethod
def backward(ctx, grad_output):
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
grad_input, grad_weight, grad_bias = fused_mix_prec_layer_norm_cuda.backward_affine(
grad_output.contiguous(), mean, invvar,
input_, ctx.normalized_shape,
weight_, bias_, ctx.eps)
return grad_input, grad_weight, grad_bias, None, None
class FusedLayerNormFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, normalized_shape, eps):
global fused_layer_norm_cuda
if fused_layer_norm_cuda is None:
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
output, mean, invvar = fused_layer_norm_cuda.forward(
input_, ctx.normalized_shape, ctx.eps)
ctx.save_for_backward(input_, mean, invvar)
return output
@staticmethod
def backward(ctx, grad_output):
input_, mean, invvar = ctx.saved_tensors
grad_input = None
grad_input = fused_layer_norm_cuda.backward(
grad_output.contiguous(), mean, invvar,
input_, ctx.normalized_shape,
ctx.eps)
return grad_input, None, None
def fused_layer_norm_affine(input, normalized_shape, weight, bias, eps=1e-6):
return FusedLayerNormAffineFunction.apply(input, weight, bias, normalized_shape, eps)
def fused_layer_norm(input, normalized_shape, eps=1e-6):
return FusedLayerNormFunction.apply(input, normalized_shape, eps)
class MixedFusedLayerNorm(torch.nn.Module):
r"""Applies Layer Normalization over a mini-batch of inputs as described in
the paper `Layer Normalization`_ .
Currently only runs on cuda() tensors.
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The mean and standard-deviation are calculated separately over the last
certain number dimensions which have to be of the shape specified by
:attr:`normalized_shape`.
:math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
:attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
.. note::
Unlike Batch Normalization and Instance Normalization, which applies
scalar scale and bias for each entire channel/plane with the
:attr:`affine` option, Layer Normalization applies per-element scale and
bias with :attr:`elementwise_affine`.
This layer uses statistics computed from input data in both training and
evaluation modes.
Args:
normalized_shape (int or list or torch.Size): input shape from an expected input
of size
.. math::
[* \times \text{normalized}\_\text{shape}[0] \times \text{normalized}\_\text{shape}[1]
\times \ldots \times \text{normalized}\_\text{shape}[-1]]
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps: a value added to the denominator for numerical stability. Default: 1e-5
elementwise_affine: a boolean value that when set to ``True``, this module
has learnable per-element affine parameters initialized to ones (for weights)
and zeros (for biases). Default: ``True``.
Shape:
- Input: :math:`(N, *)`
- Output: :math:`(N, *)` (same shape as input)
Examples::
>>> input = torch.randn(20, 5, 10, 10)
>>> # With Learnable Parameters
>>> m = apex.normalization.FusedLayerNorm(input.size()[1:])
>>> # Without Learnable Parameters
>>> m = apex.normalization.FusedLayerNorm(input.size()[1:], elementwise_affine=False)
>>> # Normalize over last two dimensions
>>> m = apex.normalization.FusedLayerNorm([10, 10])
>>> # Normalize over last dimension of size 10
>>> m = apex.normalization.FusedLayerNorm(10)
>>> # Activating the module
>>> output = m(input)
.. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
"""
def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
super(MixedFusedLayerNorm, self).__init__()
global fused_layer_norm_cuda
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
global fused_mix_prec_layer_norm_cuda
fused_mix_prec_layer_norm_cuda = importlib.import_module("fused_mix_prec_layer_norm_cuda")
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = Parameter(torch.Tensor(*normalized_shape))
self.bias = Parameter(torch.Tensor(*normalized_shape))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
if self.elementwise_affine:
init.ones_(self.weight)
init.zeros_(self.bias)
def forward(self, input):
if not input.is_cuda:
return F.layer_norm(
input, self.normalized_shape, self.weight, self.bias, self.eps)
if self.elementwise_affine:
return FusedLayerNormAffineFunction.apply(
input, self.weight, self.bias, self.normalized_shape,self.eps)
else:
return FusedLayerNormFunction.apply(input, self.normalized_shape, self.eps)
def extra_repr(self):
return '{normalized_shape}, eps={eps}, ' \
'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
...@@ -21,9 +21,9 @@ import torch.nn.functional as F ...@@ -21,9 +21,9 @@ import torch.nn.functional as F
from megatron import get_args from megatron import get_args
from megatron import mpu from megatron import mpu
from megatron.mpu import LayerNorm
from megatron.module import MegatronModule from megatron.module import MegatronModule
from megatron.checkpointing import get_checkpoint_version from megatron.checkpointing import get_checkpoint_version
from megatron.model import import_layernorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import openai_gelu, erf_gelu from megatron.model.utils import openai_gelu, erf_gelu
...@@ -404,6 +404,7 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -404,6 +404,7 @@ class ParallelTransformerLayer(MegatronModule):
= args.apply_residual_connection_post_layernorm = args.apply_residual_connection_post_layernorm
# Layernorm on the input data. # Layernorm on the input data.
LayerNorm = import_layernorm(args.fp32_residual_connection)
self.input_layernorm = LayerNorm( self.input_layernorm = LayerNorm(
args.hidden_size, args.hidden_size,
eps=args.layernorm_epsilon) eps=args.layernorm_epsilon)
...@@ -500,6 +501,8 @@ class ParallelTransformer(MegatronModule): ...@@ -500,6 +501,8 @@ class ParallelTransformer(MegatronModule):
super(ParallelTransformer, self).__init__() super(ParallelTransformer, self).__init__()
args = get_args() args = get_args()
self.fp32_residual_connection = args.fp32_residual_connection
# Store activation checkpoiting flag. # Store activation checkpoiting flag.
self.checkpoint_activations = args.checkpoint_activations self.checkpoint_activations = args.checkpoint_activations
self.checkpoint_num_layers = args.checkpoint_num_layers self.checkpoint_num_layers = args.checkpoint_num_layers
...@@ -520,6 +523,7 @@ class ParallelTransformer(MegatronModule): ...@@ -520,6 +523,7 @@ class ParallelTransformer(MegatronModule):
if mpu.is_pipeline_last_stage(): if mpu.is_pipeline_last_stage():
# Final layer norm before output. # Final layer norm before output.
LayerNorm = import_layernorm(args.fp32_residual_connection)
self.final_layernorm = LayerNorm( self.final_layernorm = LayerNorm(
args.hidden_size, args.hidden_size,
eps=args.layernorm_epsilon) eps=args.layernorm_epsilon)
...@@ -564,7 +568,12 @@ class ParallelTransformer(MegatronModule): ...@@ -564,7 +568,12 @@ class ParallelTransformer(MegatronModule):
if mpu.is_pipeline_first_stage(): if mpu.is_pipeline_first_stage():
# Data format change to avoid explicit tranposes : [b s h] --> [s b h]. # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
hidden_states = hidden_states.transpose(0, 1).contiguous() # If the input flag for fp32 residual connection is set, convert for float.
if self.fp32_residual_connection:
hidden_states = hidden_states.transpose(0, 1).contiguous().float()
# Otherwise, leave it as is.
else:
hidden_states = hidden_states.transpose(0, 1).contiguous()
if self.checkpoint_activations: if self.checkpoint_activations:
hidden_states = self._checkpointed_forward(hidden_states, hidden_states = self._checkpointed_forward(hidden_states,
......
...@@ -19,8 +19,8 @@ import math ...@@ -19,8 +19,8 @@ import math
import torch import torch
from .transformer import LayerNorm from megatron import get_args
from megatron.model import import_layernorm
def init_method_normal(sigma): def init_method_normal(sigma):
"""Init method based on N(0, sigma).""" """Init method based on N(0, sigma)."""
...@@ -65,6 +65,10 @@ def get_params_for_weight_decay_optimization(module): ...@@ -65,6 +65,10 @@ def get_params_for_weight_decay_optimization(module):
"""Divide params into with-weight-decay and without-weight-decay groups. """Divide params into with-weight-decay and without-weight-decay groups.
Layernorms and baises will have no weight decay but the rest will. Layernorms and baises will have no weight decay but the rest will.
""" """
args = get_args()
LayerNorm = import_layernorm(args.fp32_residual_connection)
weight_decay_params = {'params': []} weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0} no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
for module_ in module.modules(): for module_ in module.modules():
......
...@@ -41,7 +41,6 @@ from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_mod ...@@ -41,7 +41,6 @@ from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_mod
from .initialize import initialize_model_parallel from .initialize import initialize_model_parallel
from .initialize import model_parallel_is_initialized from .initialize import model_parallel_is_initialized
from .layers import LayerNorm
from .layers import ColumnParallelLinear from .layers import ColumnParallelLinear
from .layers import RowParallelLinear from .layers import RowParallelLinear
from .layers import VocabParallelEmbedding from .layers import VocabParallelEmbedding
......
...@@ -25,16 +25,6 @@ import torch.nn.functional as F ...@@ -25,16 +25,6 @@ import torch.nn.functional as F
import torch.nn.init as init import torch.nn.init as init
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
try:
from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
# Try to use FusedLayerNorm from Apex - this will trigger an error.
_ = LayerNorm(8, eps=1e-5)
except Exception as e:
print('WARNING: APEX is not installed, using torch.nn.LayerNorm '
'instead of apex.normalization.FusedLayerNorm!')
from torch.nn import LayerNorm
from .initialize import get_tensor_model_parallel_rank from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_world_size from .initialize import get_tensor_model_parallel_world_size
from .mappings import copy_to_tensor_model_parallel_region from .mappings import copy_to_tensor_model_parallel_region
......
...@@ -333,16 +333,19 @@ def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward) ...@@ -333,16 +333,19 @@ def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward)
tensor_recv_prev = None tensor_recv_prev = None
tensor_recv_next = None tensor_recv_next = None
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size) tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
dtype = args.params_dtype
if args.fp32_residual_connection:
dtype = torch.float
if recv_forward: if recv_forward:
tensor_recv_prev = torch.empty(tensor_shape, tensor_recv_prev = torch.empty(tensor_shape,
requires_grad=True, requires_grad=True,
device=torch.cuda.current_device(), device=torch.cuda.current_device(),
dtype=args.params_dtype) dtype=dtype)
if recv_backward: if recv_backward:
tensor_recv_next = torch.empty(tensor_shape, tensor_recv_next = torch.empty(tensor_shape,
requires_grad=True, requires_grad=True,
device=torch.cuda.current_device(), device=torch.cuda.current_device(),
dtype=args.params_dtype) dtype=dtype)
# Send tensors in both the forward and backward directions as appropriate. # Send tensors in both the forward and backward directions as appropriate.
torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev, torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment