Commit a5acbf53 authored by Mostofa Patwary

Merge branch 'main' into main_retriver_merge_ict_eval

parents 40565390 a6e00d97
@@ -19,7 +19,6 @@ import argparse
import os
import torch
from megatron import fused_kernels
def parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False):
@@ -134,8 +133,7 @@ def parse_args(extra_args_provider=None, defaults={},
if args.bf16:
assert not args.fp16
args.params_dtype = torch.bfloat16
# No fusion is support for bfloat for now
assert not args.masked_softmax_fusion
# Jitting fusion is not supported for bfloat for now
assert not args.bias_gelu_fusion
assert not args.bias_dropout_fusion
@@ -227,31 +225,6 @@ def parse_args(extra_args_provider=None, defaults={},
'for distribute-checkpointed-activations to work you '\
'need to enable checkpoint-activations'
# custom kernel constraints check
seq_len = args.seq_length
attn_batch_size = \
(args.num_attention_heads / args.tensor_model_parallel_size) * \
args.micro_batch_size
# constraints on sequence length and attn_batch_size to enable warp based
# optimization and upper triangular optimization (for causal mask)
custom_kernel_constraint = seq_len > 16 and seq_len <=2048 and \
seq_len % 4 == 0 and attn_batch_size % 4 == 0
if not (args.fp16 and custom_kernel_constraint and args.masked_softmax_fusion):
print('WARNING: constraints for invoking optimized'
' fused softmax kernel are not met. We default back to unfused'
' kernel invocations.')
# Load scaled_masked_softmax_fusion_kernels
if args.masked_softmax_fusion:
fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
fused_kernels.load_scaled_masked_softmax_fusion_kernel()
# Load mixed precision fused layer norm.
if args.fp32_residual_connection:
fused_kernels.load_fused_mix_prec_layer_norm_kernel()
_print_args(args)
return args
...
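The block removed above was the old gate for the hand-written softmax kernels: it only allowed the fused path for fp16 with 16 < seq_length <= 2048, seq_length divisible by 4, and a per-partition attention batch divisible by 4, and it loaded each kernel separately. A minimal sketch of that check, using the same argument names as the removed code (the function name itself is ours, not part of the commit):

```python
def fused_softmax_kernel_available(args):
    # Constraints from the removed check: the warp-based and upper-triangular
    # (causal-mask) optimizations need these sizes.
    seq_len = args.seq_length
    attn_batch_size = (args.num_attention_heads /
                       args.tensor_model_parallel_size) * args.micro_batch_size
    custom_kernel_constraint = (16 < seq_len <= 2048 and
                                seq_len % 4 == 0 and
                                attn_batch_size % 4 == 0)
    return args.fp16 and custom_kernel_constraint and args.masked_softmax_fusion
```

With this commit the per-kernel loaders disappear from argument parsing; building the extensions is consolidated in fused_kernels.load(args), shown further down in this diff.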
@@ -13,114 +13,97 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import subprocess
import os
from torch.utils import cpp_extension
# Setting this param to a list has a problem of generating
# different compilation commands (with diferent order of architectures)
# and leading to recompilation of fused kernels.
# set it to empty string to avoid recompilation
# and assign arch flags explicity in extra_cuda_cflags below
# Setting this param to a list has a problem of generating different
# compilation commands (with diferent order of architectures) and
# leading to recompilation of fused kernels. Set it to empty string
# to avoid recompilation and assign arch flags explicity in
# extra_cuda_cflags below
os.environ["TORCH_CUDA_ARCH_LIST"] = ""
def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor
def create_build_dir(buildpath):
try:
os.mkdir(buildpath)
except OSError:
if not os.path.isdir(buildpath):
print(f"Creation of the build directory {buildpath} failed")
def load_scaled_upper_triang_masked_softmax_fusion_kernel():
# Check, if CUDA11 is installed for compute capability 8.0
def load(args):
# Check if cuda 11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
_, bare_metal_major, _ = _get_cuda_bare_metal_version(
cpp_extension.CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')
# Build path
srcpath = pathlib.Path(__file__).parent.absolute()
buildpath = srcpath / 'build'
_create_build_dir(buildpath)
create_build_dir(buildpath)
# Helper function to build the kernels.
scaled_upper_triang_masked_softmax_cuda = cpp_extension.load(
name='scaled_upper_triang_masked_softmax_cuda',
def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
return cpp_extension.load(
name=name,
sources=sources,
build_directory=buildpath,
extra_cflags=['-O3',],
extra_cuda_cflags=['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'--use_fast_math'] + extra_cuda_flags + cc_flag,
verbose=(args.rank == 0)
)
# ==============
# Fused softmax.
# ==============
if args.masked_softmax_fusion:
extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda']
# Upper triangular softmax.
sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp',
srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu']
scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper(
"scaled_upper_triang_masked_softmax_cuda",
sources, extra_cuda_flags)
srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu'],
build_directory=buildpath,
extra_cflags=['-O3',],
extra_cuda_cflags=['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + cc_flag)
def load_scaled_masked_softmax_fusion_kernel():
# Check, if CUDA11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')
srcpath = pathlib.Path(__file__).parent.absolute()
buildpath = srcpath / 'build'
create_build_dir(buildpath)
# Masked softmax.
sources=[srcpath / 'scaled_masked_softmax.cpp',
srcpath / 'scaled_masked_softmax_cuda.cu']
scaled_masked_softmax_cuda = _cpp_extention_load_helper(
"scaled_masked_softmax_cuda", sources, extra_cuda_flags)
# =================================
# Mixed precision fused layer norm.
# =================================
extra_cuda_flags = ['-maxrregcount=50']
sources=[srcpath / 'layer_norm_cuda.cpp',
srcpath / 'layer_norm_cuda_kernel.cu']
fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
"fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags)
scaled_upper_triang_masked_softmax_cuda = cpp_extension.load(
name='scaled_masked_softmax_cuda',
sources=[srcpath / 'scaled_masked_softmax.cpp',
srcpath / 'scaled_masked_softmax_cuda.cu'],
build_directory=buildpath,
extra_cflags=['-O3',],
extra_cuda_cflags=['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + cc_flag)
def load_fused_mix_prec_layer_norm_kernel():
def _get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor
# Check, if CUDA11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')
srcpath = pathlib.Path(__file__).parent.absolute()
buildpath = srcpath / 'build'
def _create_build_dir(buildpath):
try:
os.mkdir(buildpath)
except OSError:
if not os.path.isdir(buildpath):
print(f"Creation of the build directory {buildpath} failed")
create_build_dir(buildpath)
fused_mix_prec_layer_norm_cuda = cpp_extension.load(
name='fused_mix_prec_layer_norm_cuda',
sources=[srcpath / 'layer_norm_cuda.cpp',
srcpath / 'layer_norm_cuda_kernel.cu'],
build_directory=buildpath,
extra_cflags=['-O3'],
extra_cuda_cflags=['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-maxrregcount=50',
'--use_fast_math'] + cc_flag)
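The unified loader introduced above only consults two fields of the Megatron argument namespace: args.rank (build output is verbose only on rank 0) and args.masked_softmax_fusion (whether to JIT-build the two softmax kernels; the mixed-precision layer norm is always built). A hedged usage sketch; the actual call site of load() is not part of this diff, and the SimpleNamespace stands in for the real parsed arguments:

```python
from types import SimpleNamespace

from megatron import fused_kernels  # package shown in this commit

# Hypothetical stand-in for the parsed Megatron arguments.
args = SimpleNamespace(rank=0, masked_softmax_fusion=True)

# JIT-compiles the extensions into megatron/fused_kernels/build/.
fused_kernels.load(args)
```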
@@ -24,16 +24,12 @@
#include "compat.h"
namespace {
void compute_n1_n2(
at::Tensor input,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
int& n1,
int& n2)
{
int& n2) {
int idiff = input.ndimension() - normalized_shape.size();
n2 = 1;
for (int i = 0; i < (int)normalized_shape.size(); ++i) {
@@ -47,11 +43,7 @@ void compute_n1_n2(
}
void check_args(
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor gamma,
at::Tensor beta
)
@@ -62,11 +54,7 @@ void check_args(
void check_args(
at::Tensor input,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
int& n1,
int& n2
)
@@ -102,11 +90,7 @@ void check_args(
void check_args(
at::Tensor input,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor gamma,
at::Tensor beta,
int& n1,
@@ -125,60 +109,42 @@ void cuda_layer_norm(
at::Tensor* input,
int n1,
int n2,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor* gamma,
at::Tensor* beta,
double epsilon);
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<at::Tensor> layer_norm(
at::Tensor input,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
double epsilon) {
CHECK_INPUT(input);
int n1,n2;
check_args(input,normalized_shape,n1,n2);
at::Tensor output = at::empty_like(input);
at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
at::Tensor invvar = at::empty_like(mean);
cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
normalized_shape,NULL,NULL,epsilon);
return {output, mean, invvar};
}
std::vector<at::Tensor> layer_norm_affine(
at::Tensor input,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor gamma,
at::Tensor beta,
double epsilon) {
CHECK_INPUT(input);
CHECK_INPUT(gamma);
CHECK_INPUT(beta);
int n1,n2;
check_args(input,normalized_shape,gamma,beta,n1,n2);
at::Tensor output = at::empty_like(input, input.options().dtype(at::ScalarType::Half));
at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
int n1, n2;
check_args(input, normalized_shape, gamma, beta, n1, n2);
at::Tensor output = at::empty_like(
input, gamma.options().dtype(gamma.scalar_type()));
at::Tensor mean = at::empty(
{n1}, input.options().dtype(at::ScalarType::Float));
at::Tensor invvar = at::empty_like(mean);
cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
normalized_shape,&gamma,&beta,epsilon);
cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
normalized_shape, &gamma, &beta, epsilon);
return {output, mean, invvar};
}
void cuda_layer_norm_gradient(
at::Tensor* dout,
at::Tensor* mean,
@@ -186,11 +152,7 @@ void cuda_layer_norm_gradient(
at::Tensor* input,
int n1,
int n2,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor* gamma,
at::Tensor* beta,
double epsilon,
@@ -199,62 +161,41 @@ void cuda_layer_norm_gradient(
at::Tensor* grad_beta
);
at::Tensor layer_norm_gradient(
at::Tensor dout,
at::Tensor mean,
at::Tensor invvar,
at::Tensor input,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
double epsilon) {
CHECK_INPUT(dout);
CHECK_INPUT(mean);
CHECK_INPUT(invvar);
CHECK_INPUT(input);
int n1,n2;
check_args(input,normalized_shape,n1,n2);
at::Tensor grad_input = at::empty_like(input);
cuda_layer_norm_gradient(&dout,&mean,&invvar,&input,n1,n2,
normalized_shape,NULL,NULL,epsilon,
&grad_input,NULL,NULL);
return grad_input;
}
std::vector<at::Tensor> layer_norm_gradient_affine(
at::Tensor dout,
at::Tensor mean,
at::Tensor invvar,
at::Tensor input,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor gamma,
at::Tensor beta,
double epsilon) {
CHECK_INPUT(dout);
CHECK_INPUT(mean);
CHECK_INPUT(invvar);
CHECK_INPUT(input);
CHECK_INPUT(gamma);
CHECK_INPUT(beta);
int n1,n2;
check_args(input,normalized_shape,gamma,beta,n1,n2);
int n1, n2;
check_args(input, normalized_shape, gamma, beta, n1, n2);
at::Tensor grad_input = at::empty_like(input);
at::Tensor grad_gamma = at::empty_like(gamma);
at::Tensor grad_beta = at::empty_like(beta);
cuda_layer_norm_gradient(&dout,&mean,&invvar,&input,n1,n2,
normalized_shape,&gamma,&beta,epsilon,
&grad_input,&grad_gamma,&grad_beta);
cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
normalized_shape, &gamma, &beta, epsilon,
&grad_input, &grad_gamma, &grad_beta);
return {grad_input, grad_gamma, grad_beta};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
m.def("forward", &layer_norm, "LayerNorm forward (CUDA)");
m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)");
m.def("backward", &layer_norm_gradient, "LayerNorm backward (CUDA)");
m.def("forward_affine", &layer_norm_affine,
"LayerNorm forward (CUDA)");
m.def("backward_affine", &layer_norm_gradient_affine,
"LayerNorm backward (CUDA)");
}
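After this change the affine entry points are the only ones exported, the statistics (mean, invvar) are always fp32, and the output follows gamma's dtype. A hedged sketch of calling the JIT-built module from Python; the module name comes from the build in fused_kernels/__init__.py above, but the import mechanism and the autograd wrapper around it are not part of this diff:

```python
import importlib
import torch

# Assumes fused_kernels.load(args) has already built the extension.
fused_layer_norm_cuda = importlib.import_module("fused_mix_prec_layer_norm_cuda")

hidden = 1024
x = torch.randn(8, hidden, dtype=torch.bfloat16, device="cuda")
weight = torch.ones(hidden, dtype=torch.bfloat16, device="cuda")
bias = torch.zeros(hidden, dtype=torch.bfloat16, device="cuda")

# forward_affine(input, normalized_shape, gamma, beta, epsilon)
# returns (output, mean, invvar); mean/invvar are fp32, output matches gamma.
output, mean, invvar = fused_layer_norm_cuda.forward_affine(
    x, [hidden], weight, bias, 1e-5)
```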
@@ -285,15 +285,6 @@ struct SharedMemory <float>
}
};
template <>
struct SharedMemory <double>
{
__device__ double *getPointer()
{
extern __shared__ double s_double[];
return s_double;
}
};
}
template<typename T, typename U, typename V> __global__
@@ -656,6 +647,9 @@ void cuComputeGradInput(
}
}
template<typename T, typename U, typename V>
void HostApplyLayerNorm(
V* output,
@@ -671,7 +665,8 @@ void HostApplyLayerNorm(
{
auto stream = at::cuda::getCurrentCUDAStream().stream();
const dim3 threads(32,4,1);
const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const uint64_t maxGridY =
at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
int nshared =
threads.y > 1 ?
@@ -687,6 +682,7 @@ void HostApplyLayerNorm(
gamma,beta);
}
void cuda_layer_norm(
at::Tensor* output,
at::Tensor* mean,
@@ -704,21 +700,21 @@ void cuda_layer_norm(
double epsilon)
{
using namespace at;
DISPATCH_DOUBLE_FLOAT_AND_HALF(input->scalar_type(), 0, "layer_norm_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
using output_t = at::Half;
DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel",
HostApplyLayerNorm(
output->DATA_PTR<output_t>(),
mean->DATA_PTR<accscalar_t>(),
invvar->DATA_PTR<accscalar_t>(),
input->DATA_PTR<scalar_t_0>(),
output->DATA_PTR<scalar_t_out>(),
mean->DATA_PTR<float>(),
invvar->DATA_PTR<float>(),
input->DATA_PTR<scalar_t_in>(),
n1,n2,
epsilon,
gamma != NULL ? gamma->DATA_PTR<output_t>() : NULL,
beta != NULL ? beta->DATA_PTR<output_t>() : NULL);
gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL);
)
}
template<typename T, typename U, typename V>
void HostLayerNormGradient(
const V* dout,
@@ -742,10 +738,12 @@ void HostLayerNormGradient(
const int part_size = 16;
const dim3 threads2(32,4,1);
const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y *
(threads2.x + 1);
const int nshared2_b = threads2.x * threads2.y * sizeof(U);
const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(input->scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input->scalar_type()));
at::Tensor part_grad_gamma = at::empty(
{part_size,n2}, input->options().dtype(at::ScalarType::Float));
at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
dout,
@@ -770,7 +768,8 @@ void HostLayerNormGradient(
}
// compute grad_input
const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const uint64_t maxGridY =
at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
const dim3 threads1(32,4,1);
int nshared =
@@ -788,6 +787,7 @@ void HostLayerNormGradient(
grad_input);
}
void cuda_layer_norm_gradient(
at::Tensor* dout,
at::Tensor* mean,
@@ -808,22 +808,22 @@ void cuda_layer_norm_gradient(
at::Tensor* grad_beta)
{
using namespace at;
DISPATCH_FLOAT_AND_HALF(input->scalar_type(), 0, "cuComputeGradInput",
using accscalar_t = at::acc_type<scalar_t_0, true>;
using output_t = at::Half;
DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
input->scalar_type(), gamma->scalar_type(),
"cuda_layer_norm_gradient_kernel",
HostLayerNormGradient(
dout->DATA_PTR<output_t>(),
mean->DATA_PTR<accscalar_t>(),
invvar->DATA_PTR<accscalar_t>(),
dout->DATA_PTR<scalar_t_out>(),
mean->DATA_PTR<float>(),
invvar->DATA_PTR<float>(),
input,
n1,n2,
// TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
// if gamma Tensor is NULL on input.
gamma != NULL ? gamma->DATA_PTR<output_t>() : NULL,
gamma != NULL ? beta->DATA_PTR<output_t>() : NULL,
gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL,
epsilon,
grad_input->DATA_PTR<scalar_t_0>(),
gamma != NULL ? grad_gamma->DATA_PTR<output_t>() : NULL,
gamma != NULL ? grad_beta->DATA_PTR<output_t>() : NULL);
grad_input->DATA_PTR<scalar_t_in>(),
gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);
)
}
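With the default launch configuration above (threads2 = (32, 4, 1), part_size = 16) and U = float, the two candidate shared-memory sizes can be checked by hand; a small worked example in plain Python mirroring the expressions in HostLayerNormGradient (the script itself is ours, not part of the diff):

```python
# Dynamic shared memory for cuComputePartGradGammaBeta, U = float (4 bytes).
sizeof_U = 4
threads2_x, threads2_y = 32, 4

nshared2_a = 2 * sizeof_U * threads2_y * threads2_y * (threads2_x + 1)  # 2*4*4*4*33 = 4224
nshared2_b = threads2_x * threads2_y * sizeof_U                         # 32*4*4    = 512
nshared2 = max(nshared2_a, nshared2_b)
print(nshared2)  # 4224 bytes per block
```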
@@ -37,8 +37,9 @@ torch::Tensor fwd(
torch::Tensor const& mask,
float scale_factor) {
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM(input.scalar_type() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");
return fwd_cuda(input, mask, scale_factor);
@@ -52,10 +53,12 @@ torch::Tensor bwd(
AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor");
AT_ASSERTM(output_grads.scalar_type() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.scalar_type() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
...
@@ -26,6 +26,27 @@
namespace {
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }
template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }
int log2_ceil(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
@@ -90,13 +111,14 @@ __global__ void scaled_masked_softmax_warp_forward(
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = 4;
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// gridDim/blockIdx = (seq_len, attn_heads, batches)
int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH;
int pad_first_batch = 0;
if (pad_batches != 1) { // bert style
pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH;
} else { // gpt2 style
pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
}
@@ -110,29 +132,40 @@ __global__ void scaled_masked_softmax_warp_forward(
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
src += first_batch * element_count + local_idx;
dst += first_batch * element_count + local_idx;
mask += pad_first_batch * element_count + local_idx;
src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
// load data from global memory
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
input_t temp_data[ELEMENTS_PER_LDG_STG];
uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
int itr_idx = i*element_count+it*WARP_SIZE;
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
if (mask[itr_idx] != 1) {
elements[i][it] = (acc_t)src[itr_idx] * scale;
} else {
elements[i][it] = -10000.0;
}
int itr_idx = i*element_count+it*WARP_SIZE;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (temp_mask[element] != 1) {
elements[i][it + element] = (acc_t)temp_data[element] * scale;
} else {
elements[i][it + element] = -10000.0;
}
}
} else {
elements[i][it] = -std::numeric_limits<acc_t>::infinity();
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
}
}
@@ -161,15 +194,20 @@ __global__ void scaled_masked_softmax_warp_forward(
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
dst[i*element_count+it*WARP_SIZE] = (output_t)(elements[i][it] / sum[i]);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = elements[i][it + element] / sum[i];
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
} else {
break;
}
@@ -192,6 +230,7 @@ __global__ void scaled_masked_softmax_warp_backward(
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = 4;
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// gridDim/blockIdx = (seq_len, attn_heads, batches)
@@ -207,36 +246,36 @@ __global__ void scaled_masked_softmax_warp_backward(
int local_idx = threadIdx.x;
// the first element to process by the current thread
int thread_offset = first_batch * element_count + local_idx;
int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
// load data from global memory
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
input_t temp_grad[ELEMENTS_PER_LDG_STG];
input_t temp_output[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
output_reg[i][it] = output[i*element_count+it*WARP_SIZE];
} else {
output_reg[i][it] = acc_t(0);
}
}
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
grad_reg[i][it] = (acc_t)grad[i*element_count+it*WARP_SIZE] * output_reg[i][it];
} else {
grad_reg[i][it] = acc_t(0);
}
}
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count + it * WARP_SIZE);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count + it * WARP_SIZE);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
output_reg[i][it + element] = (acc_t)temp_output[element];
}
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
}
}
}
}
@@ -257,11 +296,16 @@ __global__ void scaled_masked_softmax_warp_backward(
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
gradInput[i*element_count+it*WARP_SIZE] = (output_t)(scale * (grad_reg[i][it] - output_reg[i][it] * sum[i]));
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out);
}
}
}
@@ -299,8 +343,8 @@ void dispatch_scaled_masked_softmax_forward(
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
@@ -388,7 +432,7 @@ void dispatch_scaled_masked_softmax_backward(
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = batch_count/batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
...
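The constants above fix the vector width at ELEMENTS_PER_LDG_STG = 4, so each thread now issues one 4-element load/store per iteration instead of four scalar accesses. The warp geometry itself is unchanged and follows log2_ceil of the key sequence length; a small worked example in plain Python mirroring the constexpr math (warp_geometry is our name for the sketch, not a function in the source):

```python
C10_WARP_SIZE = 32

def log2_ceil(value):
    log2_value = 0
    while (1 << log2_value) < value:
        log2_value += 1
    return log2_value

def warp_geometry(key_seq_len):
    next_power_of_two = 1 << log2_ceil(key_seq_len)
    warp_size = min(next_power_of_two, C10_WARP_SIZE)
    warp_iterations = next_power_of_two // warp_size
    warp_batch = 2 if next_power_of_two <= 128 else 1
    return warp_size, warp_iterations, warp_batch

print(warp_geometry(2048))  # (32, 64, 1): 64 elements per thread, i.e. 16 vector loads of 4
print(warp_geometry(128))   # (32, 4, 2): 4 iterations per thread, two batches per warp
```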
@@ -19,10 +19,10 @@
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
@@ -56,16 +56,20 @@ torch::Tensor fwd_cuda(
void* mask_ptr = static_cast<void*>(mask.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
dispatch_scaled_masked_softmax_forward<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(input_ptr),
reinterpret_cast<const uint8_t*>(mask_ptr),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads,
pad_batches);
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch_scaled_masked_softmax_forward",
dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr),
reinterpret_cast<const uint8_t*>(mask_ptr),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads,
pad_batches);
);
return softmax_results;
}
@@ -86,15 +90,19 @@ torch::Tensor bwd_cuda(
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
//Softmax Grad
dispatch_scaled_masked_softmax_backward<half, half, float>(
reinterpret_cast<half*>(output_grads_ptr),
reinterpret_cast<half*>(output_grads_ptr),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads);
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(),
"dispatch_scaled_masked_softmax_backward",
dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads);
);
//backward pass is completely in-place
return output_grads;
...
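After the dispatch change the same binding accepts fp16 or bf16 activations. A hedged usage sketch from Python; the import mechanism, the "forward" attribute name (mirroring the pybind names shown for the upper-triangular variant below), and the exact mask shape are assumptions here — the binding itself only asserts 4D tensors, and the kernel treats mask values equal to 1 as masked-out positions:

```python
import importlib
import torch

# Assumes megatron.fused_kernels.load(args) has already built the extension.
scaled_masked_softmax_cuda = importlib.import_module("scaled_masked_softmax_cuda")

b, heads, sq, sk = 2, 16, 128, 128
scores = torch.randn(b, heads, sq, sk, dtype=torch.bfloat16, device="cuda")
mask = torch.zeros(b, 1, sq, sk, dtype=torch.uint8, device="cuda")  # 1 = masked

probs = scaled_masked_softmax_cuda.forward(scores, mask, 1.0)  # scale_factor = 1.0
```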
@@ -33,8 +33,9 @@ torch::Tensor bwd_cuda(
torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input.scalar_type() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return fwd_cuda(input, scale_factor);
}
@@ -47,10 +48,12 @@ torch::Tensor bwd(
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.scalar_type() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.scalar_type() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
@@ -61,7 +64,7 @@ torch::Tensor bwd(
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
"Self Multihead Attention scaled, time masked softmax -- Forward.");
m.def("backward",
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
...
@@ -21,11 +21,47 @@
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <cuda_fp16.h>
#include <c10/macros/Macros.h>
namespace {
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }
template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_zero_vector(Datatype *dst);
template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(c10::BFloat16 *dst) { *dst = 0.0; }
template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) { *dst = 0.0; }
template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
int log2_ceil(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
@@ -73,7 +109,7 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) {
* Extended softmax (from native aten pytorch) with following additional features
* 1) input scaling
* 2) Implicit time (diagonal masking)
*/
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_forward(
output_t *dst,
@@ -89,10 +125,11 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = 4;
int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
int local_seq = blockIdx.x + 1;
int warp_iteration_limit = (local_seq + WARP_SIZE - 1)/WARP_SIZE;
int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
@@ -103,22 +140,36 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
src += first_batch * stride + local_idx;
dst += first_batch * stride + local_idx;
src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
// load data from global memory
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
input_t temp_data[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
elements[i][it] = (acc_t)src[i*element_count*stride+it*WARP_SIZE] * scale;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + i*element_count*stride + it*WARP_SIZE);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if ((element_index + element) < batch_element_count) {
elements[i][it+element] = (acc_t)temp_data[element] * scale;
} else {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
} else {
elements[i][it] = -std::numeric_limits<acc_t>::infinity();
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
}
}
@@ -140,26 +191,37 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
if (it < warp_iteration_limit) {
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
sum[i] += elements[i][it];
}
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < local_seq) {
dst[i*element_count*stride+it*WARP_SIZE] = (output_t)(elements[i][it] / sum[i]);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (element_index + element < local_seq) {
out[element] = elements[i][it + element] / sum[i];
} else {
out[element] = 0;
}
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE, out);
} else if (element_index < element_count) {
dst[i*element_count*stride+it*WARP_SIZE] = 0;
copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE);
} else {
break;
}
@@ -183,6 +245,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = 4;
int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
int local_seq = blockIdx.x + 1;
@@ -197,37 +260,41 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
int local_idx = threadIdx.x;
// the first element to process by the current thread
int thread_offset = first_batch * stride + local_idx;
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
// load data from global memory
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
input_t temp_grad[ELEMENTS_PER_LDG_STG];
input_t temp_output[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
output_reg[i][it] = output[i*element_count*stride+it*WARP_SIZE];
} else {
output_reg[i][it] = acc_t(0);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count * stride + it * WARP_SIZE);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (element_index + element < batch_element_count) {
output_reg[i][it + element] = (acc_t)temp_output[element];
}
}
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (element_index + element < batch_element_count) {
grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
}
}
}
}
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
grad_reg[i][it] = (acc_t)grad[i*element_count*stride+it*WARP_SIZE] * output_reg[i][it];
} else {
grad_reg[i][it] = acc_t(0);
}
}
}
acc_t sum[WARP_BATCH];
@@ -247,11 +314,16 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
gradInput[i*element_count*stride+it*WARP_SIZE] = (output_t)(scale * (grad_reg[i][it] - output_reg[i][it] * sum[i]));
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count * stride + it * WARP_SIZE, out);
}
}
}
...
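For reference, the forward kernel above computes a scaled softmax over each attention row with an implicit causal (upper-triangular) mask: row i only attends to positions 0..i, and the remaining positions are written out as zeros. A plain PyTorch equivalent, useful only as a correctness cross-check against the fused kernel (the function is ours, not part of the source):

```python
import torch

def scaled_upper_triang_masked_softmax_ref(x, scale):
    # x: [attn_batches, seq_len, seq_len], fp16 or bf16, matching the 3D
    # layout asserted by the C++ binding.
    seq_len = x.size(-1)
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool,
                                   device=x.device))
    scores = (x.float() * scale).masked_fill(~causal, float("-inf"))
    return torch.softmax(scores, dim=-1).to(x.dtype)
```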
@@ -19,10 +19,10 @@
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_upper_triang_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
@@ -46,15 +46,20 @@ torch::Tensor fwd_cuda(
void* input_ptr = static_cast<void*>(input.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
dispatch_scaled_upper_triang_masked_softmax_forward<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(input_ptr),
scale_factor,
seq_len,
seq_len,
attn_batches);
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch_scaled_upper_triang_masked_softmax_forward",
dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr),
scale_factor,
seq_len,
seq_len,
attn_batches);
);
return softmax_results;
}
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads_,
@@ -72,14 +77,18 @@ torch::Tensor bwd_cuda(
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
//Softmax Grad
dispatch_scaled_upper_triang_masked_softmax_backward<half, half, float>(
reinterpret_cast<half*>(output_grads_ptr),
reinterpret_cast<half*>(output_grads_ptr),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
scale_factor,
seq_len,
seq_len,
attn_batches);
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(),
"dispatch_scaled_upper_triang_masked_softmax_backward",
dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor,
seq_len,
seq_len,
attn_batches);
);
//backward pass is completely in-place
return output_grads;
...
...@@ -14,214 +14,78 @@ ...@@ -14,214 +14,78 @@
* limitations under the License. * limitations under the License.
*/ */
/*This code is copied from NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include "compat.h" #include "compat.h"
// Forward/backward compatiblity hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \ #define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch(TYPE) \ switch(TYPE) \
{ \ { \
case at::ScalarType::Float: \ case at::ScalarType::Half: \
{ \ { \
using scalar_t_##LEVEL = float; \ using scalar_t = at::Half; \
__VA_ARGS__; \ __VA_ARGS__; \
break; \ break; \
} \ } \
case at::ScalarType::Half: \ case at::ScalarType::BFloat16: \
{ \ { \
using scalar_t_##LEVEL = at::Half; \ using scalar_t = at::BFloat16; \
__VA_ARGS__; \ __VA_ARGS__; \
break; \ break; \
} \ } \
default: \ default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
} }
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch(TYPEIN) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_in = float; \
switch(TYPEOUT) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Byte: \
{ \
using scalar_t_##LEVEL = uint8_t; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes
(T *x,
T val,
int lanes=1,
bool share_result=false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y*blockDim.x;
int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32.
if(blockSize >= 64)
{
x[tid] = val;
__syncthreads();
}
#pragma unroll
for(int i = (blockSize >> 1); i >= 64; i >>= 1)
{
if(tid < i)
x[tid] = x[tid] + x[tid+i];
__syncthreads();
}
T final;
if(tid < 32)
{
if(blockSize >= 64)
final = x[tid] + x[tid+32];
else
final = val;
// __SYNCWARP();
#pragma unroll
for(int i = 16; i >= lanes; i >>= 1)
final = final + __shfl_down_sync(0xffffffff, final, i);
}
if(share_result)
{
if(tid < lanes)
x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes_max_op
(T *x,
T val,
int lanes=1,
bool share_result=false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y*blockDim.x;
int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32.
if(blockSize >= 64)
{
x[tid] = val;
__syncthreads();
}
#pragma unroll
for(int i = (blockSize >> 1); i >= 64; i >>= 1)
{
if(tid < i)
x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid+i]));
__syncthreads();
}
T final;
if(tid < 32)
{
if(blockSize >= 64)
final = fmaxf(fabsf(x[tid]), fabsf(x[tid+32]));
else
final = val;
// __SYNCWARP();
#pragma unroll
for(int i = 16; i >= lanes; i >>= 1)
final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
}
if(share_result)
{
if(tid < lanes)
x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
...@@ -17,16 +17,20 @@ ...@@ -17,16 +17,20 @@
import random import random
import os import os
import time
import numpy as np import numpy as np
import torch import torch
from megatron import fused_kernels
from megatron import get_adlr_autoresume from megatron import get_adlr_autoresume
from megatron import get_args from megatron import get_args
from megatron import get_tensorboard_writer from megatron import get_tensorboard_writer
from megatron import mpu from megatron import mpu
from megatron.global_vars import set_global_variables from megatron.global_vars import set_global_variables
from megatron.mpu import set_tensor_model_parallel_rank, set_tensor_model_parallel_world_size from megatron.mpu import (set_tensor_model_parallel_rank,
set_tensor_model_parallel_world_size)
def initialize_megatron(extra_args_provider=None, args_defaults={}, def initialize_megatron(extra_args_provider=None, args_defaults={},
ignore_unknown_args=False, allow_no_cuda=False): ignore_unknown_args=False, allow_no_cuda=False):
...@@ -37,8 +41,7 @@ def initialize_megatron(extra_args_provider=None, args_defaults={}, ...@@ -37,8 +41,7 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
what you are doing. what you are doing.
Returns a function to finalize distributed env initialization Returns a function to finalize distributed env initialization
(optionally, only when args.lazy_mpu_init == True) (optionally, only when args.lazy_mpu_init == True)
"""
"""
if not allow_no_cuda: if not allow_no_cuda:
# Make sure cuda is available. # Make sure cuda is available.
assert torch.cuda.is_available(), 'Megatron requires CUDA.' assert torch.cuda.is_available(), 'Megatron requires CUDA.'
...@@ -66,7 +69,8 @@ def initialize_megatron(extra_args_provider=None, args_defaults={}, ...@@ -66,7 +69,8 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
# delayed initialization of DDP-related stuff # delayed initialization of DDP-related stuff
# We only set basic DDP globals # We only set basic DDP globals
set_tensor_model_parallel_world_size(args.tensor_model_parallel_size) set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
# and return function for external DDP manager to call when it has DDP initialized # and return function for external DDP manager
# to call when it has DDP initialized
set_tensor_model_parallel_rank(args.rank) set_tensor_model_parallel_rank(args.rank)
return finish_mpu_init return finish_mpu_init
else: else:
...@@ -79,16 +83,71 @@ def initialize_megatron(extra_args_provider=None, args_defaults={}, ...@@ -79,16 +83,71 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
# Autoresume. # Autoresume.
_init_autoresume() _init_autoresume()
# Compile dataset C++ code. # Compile dependencies.
if torch.distributed.get_rank() == 0: _compile_dependencies()
from megatron.data.dataset_utils import compile_helper
compile_helper()
# Simple barrier
torch.distributed.barrier()
# No continuation function # No continuation function
return None return None
def _compile_dependencies():
args = get_args()
# =========================
# Compile dataset C++ code.
# =========================
# TODO: move this to ninja
if torch.distributed.get_rank() == 0:
start_time = time.time()
print('> compiling dataset index builder ...')
from megatron.data.dataset_utils import compile_helper
compile_helper()
print('>>> done with dataset index builder. Compilation time: {:.3f} '
'seconds'.format(time.time() - start_time), flush=True)
# ==================
# Load fused kernels
# ==================
# Custom kernel constraints check.
seq_len = args.seq_length
attn_batch_size = \
(args.num_attention_heads / args.tensor_model_parallel_size) * \
args.micro_batch_size
# Constraints on sequence length and attn_batch_size to enable warp based
# optimization and upper triangular optimization (for causal mask)
custom_kernel_constraint = seq_len > 16 and seq_len <=2048 and \
seq_len % 4 == 0 and attn_batch_size % 4 == 0
# Print a warning.
if not ((args.fp16 or args.bf16) and
custom_kernel_constraint and
args.masked_softmax_fusion):
if args.rank == 0:
print('WARNING: constraints for invoking optimized'
' fused softmax kernel are not met. We default'
' back to unfused kernel invocations.', flush=True)
# Always build on rank zero first.
if torch.distributed.get_rank() == 0:
start_time = time.time()
print('> compiling and loading fused kernels ...', flush=True)
fused_kernels.load(args)
torch.distributed.barrier()
else:
torch.distributed.barrier()
fused_kernels.load(args)
# Simple barrier to make sure all ranks have passed the
# compilation phase successfully before moving on to the
# rest of the program. We think this might ensure that
# the lock is released.
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print('>>> done with compiling and loading fused kernels. '
'Compilation time: {:.3f} seconds'.format(
time.time() - start_time), flush=True)
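The build-on-rank-zero-first pattern used above (rank 0 compiles and warms the extension cache while the other ranks wait at a barrier, then everyone loads) avoids concurrent ninja builds racing on the same lock file. A minimal standalone sketch of that pattern, assuming torch.distributed is already initialized and with a hypothetical build_extension() callable standing in for fused_kernels.load(args):

    import torch

    def load_with_rank_zero_first(build_extension):
        # Rank 0 builds first and populates the torch extension cache ...
        if torch.distributed.get_rank() == 0:
            build_extension()
        torch.distributed.barrier()
        # ... the remaining ranks then load the already-built artifacts.
        if torch.distributed.get_rank() != 0:
            build_extension()
        # Final barrier so no rank moves on before everyone has loaded.
        torch.distributed.barrier()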
def _initialize_distributed(): def _initialize_distributed():
"""Initialize torch.distributed and mpu.""" """Initialize torch.distributed and mpu."""
......
...@@ -13,23 +13,7 @@ ...@@ -13,23 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
_LAYER_NORM = None from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
def import_layernorm(fp32_residual_connection, bf16):
global _LAYER_NORM
if not _LAYER_NORM:
if bf16:
from torch.nn import LayerNorm
elif fp32_residual_connection:
from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
else:
from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
_LAYER_NORM = LayerNorm
return _LAYER_NORM
from .distributed import * from .distributed import *
from .bert_model import (BertModel, from .bert_model import (BertModel,
......
...@@ -22,7 +22,7 @@ from megatron import mpu ...@@ -22,7 +22,7 @@ from megatron import mpu
from megatron.model.enums import AttnMaskType from megatron.model.enums import AttnMaskType
from megatron.model.language_model import parallel_lm_logits from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model from megatron.model.language_model import get_language_model
from megatron.model import import_layernorm from megatron.model import LayerNorm
from megatron.model.utils import openai_gelu, erf_gelu from megatron.model.utils import openai_gelu, erf_gelu
from megatron.model.utils import get_linear_layer from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal from megatron.model.utils import init_method_normal
...@@ -78,7 +78,6 @@ class BertLMHead(MegatronModule): ...@@ -78,7 +78,6 @@ class BertLMHead(MegatronModule):
self.parallel_output = parallel_output self.parallel_output = parallel_output
self.dense = get_linear_layer(hidden_size, hidden_size, init_method) self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)
self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
self.gelu = torch.nn.functional.gelu self.gelu = torch.nn.functional.gelu
if args.openai_gelu: if args.openai_gelu:
......
...@@ -15,29 +15,23 @@ ...@@ -15,29 +15,23 @@
"""This code is copied fron NVIDIA apex: """This code is copied fron NVIDIA apex:
https://github.com/NVIDIA/apex https://github.com/NVIDIA/apex
with minor changes. """ with some changes. """
import math
import torch
import numbers import numbers
import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from torch.nn import init from torch.nn import init
from torch.nn import functional as F
import importlib import importlib
global fused_layer_norm_cuda
fused_layer_norm_cuda = None
global fused_mix_prec_layer_norm_cuda global fused_mix_prec_layer_norm_cuda
fused_mix_prec_layer_norm_cuda = None fused_mix_prec_layer_norm_cuda = None
class FusedLayerNormAffineFunction(torch.autograd.Function): class FusedLayerNormAffineFunction(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, input, weight, bias, normalized_shape, eps): def forward(ctx, input, weight, bias, normalized_shape, eps):
global fused_mix_prec_layer_norm_cuda
if fused_mix_prec_layer_norm_cuda is None:
fused_mix_prec_layer_norm_cuda = importlib.import_module("fused_mix_prec_layer_norm_cuda")
ctx.normalized_shape = normalized_shape ctx.normalized_shape = normalized_shape
ctx.eps = eps ctx.eps = eps
input_ = input.contiguous() input_ = input.contiguous()
...@@ -46,134 +40,51 @@ class FusedLayerNormAffineFunction(torch.autograd.Function): ...@@ -46,134 +40,51 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine( output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
input_, ctx.normalized_shape, weight_, bias_, ctx.eps) input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
ctx.save_for_backward(input_, weight_, bias_, mean, invvar) ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
return output return output
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
input_, weight_, bias_, mean, invvar = ctx.saved_tensors input_, weight_, bias_, mean, invvar = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None grad_input = grad_weight = grad_bias = None
grad_input, grad_weight, grad_bias = fused_mix_prec_layer_norm_cuda.backward_affine( grad_input, grad_weight, grad_bias \
= fused_mix_prec_layer_norm_cuda.backward_affine(
grad_output.contiguous(), mean, invvar, grad_output.contiguous(), mean, invvar,
input_, ctx.normalized_shape, input_, ctx.normalized_shape,
weight_, bias_, ctx.eps) weight_, bias_, ctx.eps)
return grad_input, grad_weight, grad_bias, None, None
class FusedLayerNormFunction(torch.autograd.Function):
@staticmethod return grad_input, grad_weight, grad_bias, None, None
def forward(ctx, input, normalized_shape, eps):
global fused_layer_norm_cuda
if fused_layer_norm_cuda is None:
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
output, mean, invvar = fused_layer_norm_cuda.forward(
input_, ctx.normalized_shape, ctx.eps)
ctx.save_for_backward(input_, mean, invvar)
return output
@staticmethod
def backward(ctx, grad_output):
input_, mean, invvar = ctx.saved_tensors
grad_input = None
grad_input = fused_layer_norm_cuda.backward(
grad_output.contiguous(), mean, invvar,
input_, ctx.normalized_shape,
ctx.eps)
return grad_input, None, None
def fused_layer_norm_affine(input, normalized_shape, weight, bias, eps=1e-6):
return FusedLayerNormAffineFunction.apply(input, weight, bias, normalized_shape, eps)
def fused_layer_norm(input, normalized_shape, eps=1e-6):
return FusedLayerNormFunction.apply(input, normalized_shape, eps)
class MixedFusedLayerNorm(torch.nn.Module): class MixedFusedLayerNorm(torch.nn.Module):
r"""Applies Layer Normalization over a mini-batch of inputs as described in
the paper `Layer Normalization`_ . def __init__(self, normalized_shape, eps=1e-5):
Currently only runs on cuda() tensors.
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The mean and standard-deviation are calculated separately over the last
certain number dimensions which have to be of the shape specified by
:attr:`normalized_shape`.
:math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
:attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
.. note::
Unlike Batch Normalization and Instance Normalization, which applies
scalar scale and bias for each entire channel/plane with the
:attr:`affine` option, Layer Normalization applies per-element scale and
bias with :attr:`elementwise_affine`.
This layer uses statistics computed from input data in both training and
evaluation modes.
Args:
normalized_shape (int or list or torch.Size): input shape from an expected input
of size
.. math::
[* \times \text{normalized}\_\text{shape}[0] \times \text{normalized}\_\text{shape}[1]
\times \ldots \times \text{normalized}\_\text{shape}[-1]]
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps: a value added to the denominator for numerical stability. Default: 1e-5
elementwise_affine: a boolean value that when set to ``True``, this module
has learnable per-element affine parameters initialized to ones (for weights)
and zeros (for biases). Default: ``True``.
Shape:
- Input: :math:`(N, *)`
- Output: :math:`(N, *)` (same shape as input)
Examples::
>>> input = torch.randn(20, 5, 10, 10)
>>> # With Learnable Parameters
>>> m = apex.normalization.FusedLayerNorm(input.size()[1:])
>>> # Without Learnable Parameters
>>> m = apex.normalization.FusedLayerNorm(input.size()[1:], elementwise_affine=False)
>>> # Normalize over last two dimensions
>>> m = apex.normalization.FusedLayerNorm([10, 10])
>>> # Normalize over last dimension of size 10
>>> m = apex.normalization.FusedLayerNorm(10)
>>> # Activating the module
>>> output = m(input)
.. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
"""
def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
super(MixedFusedLayerNorm, self).__init__() super(MixedFusedLayerNorm, self).__init__()
global fused_layer_norm_cuda
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
global fused_mix_prec_layer_norm_cuda global fused_mix_prec_layer_norm_cuda
fused_mix_prec_layer_norm_cuda = importlib.import_module("fused_mix_prec_layer_norm_cuda") fused_mix_prec_layer_norm_cuda = importlib.import_module(
"fused_mix_prec_layer_norm_cuda")
if isinstance(normalized_shape, numbers.Integral): if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,) normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape) self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps self.eps = eps
self.elementwise_affine = elementwise_affine self.weight = Parameter(torch.Tensor(*normalized_shape))
if self.elementwise_affine: self.bias = Parameter(torch.Tensor(*normalized_shape))
self.weight = Parameter(torch.Tensor(*normalized_shape))
self.bias = Parameter(torch.Tensor(*normalized_shape))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
self.reset_parameters() self.reset_parameters()
def reset_parameters(self):
if self.elementwise_affine: def reset_parameters(self):
init.ones_(self.weight)
init.zeros_(self.bias) init.ones_(self.weight)
init.zeros_(self.bias)
def forward(self, input):
if not input.is_cuda:
return F.layer_norm( def forward(self, input):
input, self.normalized_shape, self.weight, self.bias, self.eps)
return FusedLayerNormAffineFunction.apply(
if self.elementwise_affine: input, self.weight, self.bias, self.normalized_shape,self.eps)
return FusedLayerNormAffineFunction.apply(
input, self.weight, self.bias, self.normalized_shape,self.eps)
else:
return FusedLayerNormFunction.apply(input, self.normalized_shape, self.eps)
def extra_repr(self):
return '{normalized_shape}, eps={eps}, ' \
'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
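After the simplification, MixedFusedLayerNorm always carries an affine weight and bias and defers the arithmetic to the fused_mix_prec_layer_norm_cuda extension. As a CPU-friendly reference for what that extension computes (normalize over normalized_shape, then apply the affine transform), here is a plain-PyTorch sketch; the function name is illustrative only:

    import torch

    def layer_norm_reference(x, weight, bias, normalized_shape, eps=1e-5):
        # Normalize over the trailing `normalized_shape` dimensions, then apply affine.
        dims = tuple(range(-len(normalized_shape), 0))
        mean = x.mean(dim=dims, keepdim=True)
        var = x.var(dim=dims, unbiased=False, keepdim=True)
        return (x - mean) * torch.rsqrt(var + eps) * weight + bias

    x = torch.randn(2, 4, 8)
    w, b = torch.ones(8), torch.zeros(8)
    assert torch.allclose(layer_norm_reference(x, w, b, (8,)),
                          torch.nn.functional.layer_norm(x, (8,), w, b),
                          atol=1e-6)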
...@@ -96,6 +96,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module): ...@@ -96,6 +96,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
def __init__( def __init__(
self, self,
input_in_fp16, input_in_fp16,
input_in_bf16,
attn_mask_type, attn_mask_type,
scaled_masked_softmax_fusion, scaled_masked_softmax_fusion,
mask_func, mask_func,
...@@ -104,6 +105,10 @@ class FusedScaleMaskSoftmax(torch.nn.Module): ...@@ -104,6 +105,10 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
): ):
super(FusedScaleMaskSoftmax, self).__init__() super(FusedScaleMaskSoftmax, self).__init__()
self.input_in_fp16 = input_in_fp16 self.input_in_fp16 = input_in_fp16
self.input_in_bf16 = input_in_bf16
assert not (self.input_in_fp16 and self.input_in_bf16),\
'both fp16 and bf16 flags cannot be active at the same time.'
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
self.attn_mask_type = attn_mask_type self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
self.mask_func = mask_func self.mask_func = mask_func
...@@ -128,8 +133,8 @@ class FusedScaleMaskSoftmax(torch.nn.Module): ...@@ -128,8 +133,8 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
query_seq_len % 4 == 0 and attn_batch_size % 4 == 0 query_seq_len % 4 == 0 and attn_batch_size % 4 == 0
# invoke custom kernel # invoke custom kernel
if self.input_in_fp16 and mask is not None and \ if self.input_in_float16 and mask is not None and \
custom_kernel_constraint and self.scaled_masked_softmax_fusion: custom_kernel_constraint and self.scaled_masked_softmax_fusion:
scale = self.scale if self.scale is not None else 1.0 scale = self.scale if self.scale is not None else 1.0
if self.attn_mask_type == AttnMaskType.causal: if self.attn_mask_type == AttnMaskType.causal:
...@@ -142,7 +147,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module): ...@@ -142,7 +147,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
assert self.attn_mask_type == AttnMaskType.padding assert self.attn_mask_type == AttnMaskType.padding
probs = ScaledMaskedSoftmax.apply(input, mask, scale) probs = ScaledMaskedSoftmax.apply(input, mask, scale)
else: else:
if self.input_in_fp16 and self.softmax_in_fp32: if self.input_in_float16 and self.softmax_in_fp32:
input = input.float() input = input.float()
if self.scale is not None: if self.scale is not None:
...@@ -150,7 +155,10 @@ class FusedScaleMaskSoftmax(torch.nn.Module): ...@@ -150,7 +155,10 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
mask_output = self.mask_func(input, mask) if mask is not None else input mask_output = self.mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output) probs = torch.nn.Softmax(dim=-1)(mask_output)
if self.input_in_fp16 and self.softmax_in_fp32: if self.input_in_float16 and self.softmax_in_fp32:
probs = probs.half() if self.input_in_fp16:
probs = probs.half()
else:
probs = probs.bfloat16()
return probs return probs
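The unfused branch now handles both half and bfloat16 inputs: upcast to fp32 when softmax_in_fp32 is set, scale, mask, softmax, then cast back to the original 16-bit type. A standalone sketch of that fallback path (the masking constant and tensor shapes are illustrative, not the exact mask_func from the source):

    import torch

    def unfused_scale_mask_softmax(scores, mask, scale, softmax_in_fp32=True):
        input_dtype = scores.dtype
        x = scores.float() if softmax_in_fp32 else scores
        if scale is not None:
            x = x * scale
        if mask is not None:
            x = x.masked_fill(mask, -10000.0)  # stand-in for the attention mask_func
        probs = torch.softmax(x, dim=-1)
        if softmax_in_fp32 and input_dtype in (torch.float16, torch.bfloat16):
            probs = probs.to(input_dtype)      # .half() or .bfloat16(), as in the diff
        return probs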
...@@ -22,7 +22,7 @@ from megatron import get_args ...@@ -22,7 +22,7 @@ from megatron import get_args
from megatron import mpu from megatron import mpu
from .module import MegatronModule from .module import MegatronModule
from megatron.model.enums import AttnMaskType, LayerType, AttnType from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model import import_layernorm from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
...@@ -116,6 +116,7 @@ class ParallelAttention(MegatronModule): ...@@ -116,6 +116,7 @@ class ParallelAttention(MegatronModule):
super(ParallelAttention, self).__init__() super(ParallelAttention, self).__init__()
args = get_args() args = get_args()
self.fp16 = args.fp16 self.fp16 = args.fp16
self.bf16 = args.bf16
self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32 self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
...@@ -164,7 +165,7 @@ class ParallelAttention(MegatronModule): ...@@ -164,7 +165,7 @@ class ParallelAttention(MegatronModule):
self.norm_factor *= coeff self.norm_factor *= coeff
self.scale_mask_softmax = FusedScaleMaskSoftmax( self.scale_mask_softmax = FusedScaleMaskSoftmax(
self.fp16, self.fp16, self.bf16,
self.attn_mask_type, self.attn_mask_type,
args.masked_softmax_fusion, args.masked_softmax_fusion,
attention_mask_func, attention_mask_func,
...@@ -401,7 +402,6 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -401,7 +402,6 @@ class ParallelTransformerLayer(MegatronModule):
self.fp32_residual_connection = args.fp32_residual_connection self.fp32_residual_connection = args.fp32_residual_connection
# Layernorm on the input data. # Layernorm on the input data.
LayerNorm = import_layernorm(self.fp32_residual_connection, self.bf16)
self.input_layernorm = LayerNorm( self.input_layernorm = LayerNorm(
args.hidden_size, args.hidden_size,
eps=args.layernorm_epsilon) eps=args.layernorm_epsilon)
...@@ -443,8 +443,6 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -443,8 +443,6 @@ class ParallelTransformerLayer(MegatronModule):
# Layer norm at the beginning of the transformer layer. # Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states) layernorm_output = self.input_layernorm(hidden_states)
if self.bf16 and self.fp32_residual_connection:
layernorm_output = layernorm_output.bfloat16()
# Self attention. # Self attention.
attention_output, attention_bias = \ attention_output, attention_bias = \
self.self_attention(layernorm_output, self.self_attention(layernorm_output,
...@@ -483,8 +481,6 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -483,8 +481,6 @@ class ParallelTransformerLayer(MegatronModule):
# Layer norm post the self attention. # Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input) layernorm_output = self.post_attention_layernorm(layernorm_input)
if self.bf16 and self.fp32_residual_connection:
layernorm_output = layernorm_output.bfloat16()
if self.layer_type == LayerType.decoder: if self.layer_type == LayerType.decoder:
attention_output, attention_bias = \ attention_output, attention_bias = \
...@@ -507,8 +503,6 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -507,8 +503,6 @@ class ParallelTransformerLayer(MegatronModule):
# Layer norm post the decoder attention # Layer norm post the decoder attention
layernorm_output = self.post_inter_attention_layernorm(layernorm_input) layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
if self.bf16 and self.fp32_residual_connection:
layernorm_output = layernorm_output.bfloat16()
# MLP. # MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output) mlp_output, mlp_bias = self.mlp(layernorm_output)
...@@ -588,8 +582,6 @@ class ParallelTransformer(MegatronModule): ...@@ -588,8 +582,6 @@ class ParallelTransformer(MegatronModule):
if mpu.is_pipeline_last_stage(): if mpu.is_pipeline_last_stage():
# Final layer norm before output. # Final layer norm before output.
LayerNorm = import_layernorm(self.fp32_residual_connection,
self.bf16)
self.final_layernorm = LayerNorm( self.final_layernorm = LayerNorm(
args.hidden_size, args.hidden_size,
eps=args.layernorm_epsilon) eps=args.layernorm_epsilon)
...@@ -676,8 +668,6 @@ class ParallelTransformer(MegatronModule): ...@@ -676,8 +668,6 @@ class ParallelTransformer(MegatronModule):
# Reverting data format change [s b h] --> [b s h]. # Reverting data format change [s b h] --> [b s h].
hidden_states = hidden_states.transpose(0, 1).contiguous() hidden_states = hidden_states.transpose(0, 1).contiguous()
output = self.final_layernorm(hidden_states) output = self.final_layernorm(hidden_states)
if self.bf16 and self.fp32_residual_connection:
output = output.bfloat16()
else: else:
output = hidden_states output = hidden_states
if get_key_value: if get_key_value:
......
...@@ -17,7 +17,7 @@ from apex.optimizers import FusedAdam as Adam ...@@ -17,7 +17,7 @@ from apex.optimizers import FusedAdam as Adam
from apex.optimizers import FusedSGD as SGD from apex.optimizers import FusedSGD as SGD
from megatron import get_args from megatron import get_args
from megatron.model import import_layernorm from megatron.model import LayerNorm
from .grad_scaler import ConstantGradScaler, DynamicGradScaler from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
...@@ -27,8 +27,6 @@ def _get_params_for_weight_decay_optimization(modules): ...@@ -27,8 +27,6 @@ def _get_params_for_weight_decay_optimization(modules):
"""Divide params into with-weight-decay and without-weight-decay groups. """Divide params into with-weight-decay and without-weight-decay groups.
Layernorms and biases will have no weight decay but the rest will. Layernorms and biases will have no weight decay but the rest will.
""" """
args = get_args()
LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)
weight_decay_params = {'params': []} weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0} no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
......
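With LayerNorm now a single module class, the weight-decay split can key directly on isinstance checks: layer-norm parameters and all biases go into the no-decay group, everything else keeps weight decay. A short sketch of that grouping logic (module iteration and names are illustrative, not the exact source function):

    import torch
    from torch.nn import LayerNorm  # stand-in for megatron.model.LayerNorm

    def split_params_for_weight_decay(modules):
        weight_decay_params = {'params': []}
        no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
        for module in modules:
            for m in module.modules():
                if isinstance(m, LayerNorm):
                    # All layer-norm parameters (weight and bias) skip weight decay.
                    no_weight_decay_params['params'].extend(
                        p for p in m._parameters.values() if p is not None)
                else:
                    for name, p in m._parameters.items():
                        if p is None:
                            continue
                        (no_weight_decay_params if name == 'bias'
                         else weight_decay_params)['params'].append(p)
        return weight_decay_params, no_weight_decay_params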
...@@ -224,16 +224,6 @@ def get_model(model_provider_func): ...@@ -224,16 +224,6 @@ def get_model(model_provider_func):
# Fp16 conversion. # Fp16 conversion.
if args.fp16 or args.bf16: if args.fp16 or args.bf16:
model = [Float16Module(model_module, args) for model_module in model] model = [Float16Module(model_module, args) for model_module in model]
# For now, the layer norm does not support input float32 and outut bf16.
# For this, we move layernorm parameters to fp32 and cast output of the
# layernorm operation back to bf16.
if args.bf16 and args.fp32_residual_connection:
from megatron.model import import_layernorm
LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)
for model_ in model:
for module_ in model_.modules():
if isinstance(module_, LayerNorm):
module_.float()
if args.DDP_impl == 'torch': if args.DDP_impl == 'torch':
i = torch.cuda.current_device() i = torch.cuda.current_device()
......