Commit 0fa7175f authored by Mohammad Shoeybi, committed by Jared Casper

Bfloat fused softmax + fused layer norm

parent d9b1c681
@@ -133,8 +133,7 @@ def parse_args(extra_args_provider=None, defaults={},
     if args.bf16:
         assert not args.fp16
         args.params_dtype = torch.bfloat16
-        # No fusion is support for bfloat for now
-        assert not args.masked_softmax_fusion
+        # Jitting fusion is not supported for bfloat for now
         assert not args.bias_gelu_fusion
         assert not args.bias_dropout_fusion
......
@@ -82,12 +82,11 @@ def load(args):
     # Mixed precision fused layer norm.
     # =================================
-    if args.fp32_residual_connection:
-        extra_cuda_flags = ['-maxrregcount=50']
-        sources=[srcpath / 'layer_norm_cuda.cpp',
-                 srcpath / 'layer_norm_cuda_kernel.cu']
-        fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
-            "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags)
+    extra_cuda_flags = ['-maxrregcount=50']
+    sources=[srcpath / 'layer_norm_cuda.cpp',
+             srcpath / 'layer_norm_cuda_kernel.cu']
+    fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
+        "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags)


 def _get_cuda_bare_metal_version(cuda_dir):
......
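A helper named _cpp_extention_load_helper already exists in this file; the sketch below is a hedged guess at what such a helper does, shown only to make the build step above concrete. Only the helper name, the sources list, and the '-maxrregcount=50' flag come from the diff; the remaining flags and the use of verbose output are assumptions.

# Hedged sketch of a JIT-build helper; the real Megatron helper may differ.
from torch.utils import cpp_extension

def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
    """Build (or load from cache) a CUDA extension with the given flags."""
    return cpp_extension.load(
        name=name,
        sources=[str(path) for path in sources],
        extra_cflags=['-O3'],                          # assumption
        extra_cuda_cflags=['-O3'] + extra_cuda_flags,  # includes '-maxrregcount=50'
        verbose=True)

# Usage mirroring the hunk above (srcpath is a pathlib.Path to the kernel sources):
# fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
#     "fused_mix_prec_layer_norm_cuda",
#     [srcpath / 'layer_norm_cuda.cpp', srcpath / 'layer_norm_cuda_kernel.cu'],
#     ['-maxrregcount=50'])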
@@ -24,12 +24,12 @@
 #include "compat.h"

 namespace {

 void compute_n1_n2(
     at::Tensor input,
     at::IntArrayRef normalized_shape,
     int& n1,
-    int& n2)
-{
+    int& n2) {
     int idiff = input.ndimension() - normalized_shape.size();
     n2 = 1;
     for (int i = 0; i < (int)normalized_shape.size(); ++i) {
@@ -118,39 +118,33 @@ void cuda_layer_norm(
 #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

-std::vector<at::Tensor> layer_norm(
-    at::Tensor input,
-    at::IntArrayRef normalized_shape,
-    double epsilon) {
-  CHECK_INPUT(input);
-  int n1,n2;
-  check_args(input,normalized_shape,n1,n2);
-  at::Tensor output = at::empty_like(input);
-  at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
-  at::Tensor invvar = at::empty_like(mean);
-  cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
-      normalized_shape,NULL,NULL,epsilon);
-  return {output, mean, invvar};
-}
-
 std::vector<at::Tensor> layer_norm_affine(
     at::Tensor input,
     at::IntArrayRef normalized_shape,
     at::Tensor gamma,
     at::Tensor beta,
     double epsilon) {
   CHECK_INPUT(input);
   CHECK_INPUT(gamma);
   CHECK_INPUT(beta);
-  int n1,n2;
-  check_args(input,normalized_shape,gamma,beta,n1,n2);
-  at::Tensor output = at::empty_like(input, input.options().dtype(at::ScalarType::Half));
-  at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
+  int n1, n2;
+  check_args(input, normalized_shape, gamma, beta, n1, n2);
+  at::Tensor output = at::empty_like(
+      input, gamma.options().dtype(gamma.scalar_type()));
+  at::Tensor mean = at::empty(
+      {n1}, input.options().dtype(at::ScalarType::Float));
   at::Tensor invvar = at::empty_like(mean);
-  cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
-      normalized_shape,&gamma,&beta,epsilon);
+  cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
+                  normalized_shape, &gamma, &beta, epsilon);
   return {output, mean, invvar};
 }

 void cuda_layer_norm_gradient(
     at::Tensor* dout,
     at::Tensor* mean,
@@ -167,25 +161,6 @@ void cuda_layer_norm_gradient(
     at::Tensor* grad_beta
     );

-at::Tensor layer_norm_gradient(
-    at::Tensor dout,
-    at::Tensor mean,
-    at::Tensor invvar,
-    at::Tensor input,
-    at::IntArrayRef normalized_shape,
-    double epsilon) {
-  CHECK_INPUT(dout);
-  CHECK_INPUT(mean);
-  CHECK_INPUT(invvar);
-  CHECK_INPUT(input);
-  int n1,n2;
-  check_args(input,normalized_shape,n1,n2);
-  at::Tensor grad_input = at::empty_like(input);
-  cuda_layer_norm_gradient(&dout,&mean,&invvar,&input,n1,n2,
-      normalized_shape,NULL,NULL,epsilon,
-      &grad_input,NULL,NULL);
-  return grad_input;
-}
-
 std::vector<at::Tensor> layer_norm_gradient_affine(
     at::Tensor dout,
     at::Tensor mean,
@@ -195,26 +170,32 @@ std::vector<at::Tensor> layer_norm_gradient_affine(
     at::Tensor gamma,
     at::Tensor beta,
     double epsilon) {
   CHECK_INPUT(dout);
   CHECK_INPUT(mean);
   CHECK_INPUT(invvar);
   CHECK_INPUT(input);
   CHECK_INPUT(gamma);
   CHECK_INPUT(beta);
-  int n1,n2;
-  check_args(input,normalized_shape,gamma,beta,n1,n2);
+  int n1, n2;
+  check_args(input, normalized_shape, gamma, beta, n1, n2);
   at::Tensor grad_input = at::empty_like(input);
   at::Tensor grad_gamma = at::empty_like(gamma);
   at::Tensor grad_beta = at::empty_like(beta);
-  cuda_layer_norm_gradient(&dout,&mean,&invvar,&input,n1,n2,
-      normalized_shape,&gamma,&beta,epsilon,
-      &grad_input,&grad_gamma,&grad_beta);
+  cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
+                           normalized_shape, &gamma, &beta, epsilon,
+                           &grad_input, &grad_gamma, &grad_beta);
   return {grad_input, grad_gamma, grad_beta};
 }

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
-  m.def("forward", &layer_norm, "LayerNorm forward (CUDA)");
-  m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)");
-  m.def("backward", &layer_norm_gradient, "LayerNorm backward (CUDA)");
+  m.def("forward_affine", &layer_norm_affine,
+        "LayerNorm forward (CUDA)");
+  m.def("backward_affine", &layer_norm_gradient_affine,
+        "LayerNorm backward (CUDA)");
 }
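After this change, forward_affine derives the output dtype from gamma instead of hard-coding half, and the mean/invvar statistics are always allocated in fp32. A minimal Python sketch of what that implies for a caller, assuming the extension has been JIT-built under the module name fused_mix_prec_layer_norm_cuda as in the build step above:

# Hedged sketch: exercises the rebuilt extension's forward_affine and checks
# the dtypes implied by the C++ changes above (output follows gamma/beta,
# statistics are kept in fp32).
import importlib
import torch

fused_ln = importlib.import_module("fused_mix_prec_layer_norm_cuda")

hidden = 1024
x = torch.randn(8, hidden, device="cuda", dtype=torch.float32)
weight = torch.ones(hidden, device="cuda", dtype=torch.bfloat16)
bias = torch.zeros(hidden, device="cuda", dtype=torch.bfloat16)

output, mean, invvar = fused_ln.forward_affine(
    x.contiguous(), (hidden,), weight, bias, 1e-5)

assert output.dtype == torch.bfloat16   # follows gamma's dtype
assert mean.dtype == torch.float32      # statistics stay in fp32
assert invvar.dtype == torch.float32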
@@ -285,15 +285,6 @@ struct SharedMemory <float>
     }
 };

-template <>
-struct SharedMemory <double>
-{
-    __device__ double *getPointer()
-    {
-        extern __shared__ double s_double[];
-        return s_double;
-    }
-};
 }
 template<typename T, typename U, typename V> __global__
@@ -656,6 +647,9 @@ void cuComputeGradInput(
   }
 }


 template<typename T, typename U, typename V>
 void HostApplyLayerNorm(
     V* output,
@@ -671,7 +665,8 @@ void HostApplyLayerNorm(
 {
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     const dim3 threads(32,4,1);
-    const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
+    const uint64_t maxGridY =
+        at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
     const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
     int nshared =
         threads.y > 1 ?
@@ -687,6 +682,7 @@ void HostApplyLayerNorm(
           gamma,beta);
 }


 void cuda_layer_norm(
     at::Tensor* output,
     at::Tensor* mean,
@@ -704,21 +700,21 @@ void cuda_layer_norm(
     double epsilon)
 {
     using namespace at;
-    DISPATCH_DOUBLE_FLOAT_AND_HALF(input->scalar_type(), 0, "layer_norm_cuda_kernel",
-        using accscalar_t = at::acc_type<scalar_t_0, true>;
-        using output_t = at::Half;
-        HostApplyLayerNorm(
-            output->DATA_PTR<output_t>(),
-            mean->DATA_PTR<accscalar_t>(),
-            invvar->DATA_PTR<accscalar_t>(),
-            input->DATA_PTR<scalar_t_0>(),
-            n1,n2,
-            epsilon,
-            gamma != NULL ? gamma->DATA_PTR<output_t>() : NULL,
-            beta != NULL ? beta->DATA_PTR<output_t>() : NULL);
-    )
+    DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
+        input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel",
+        HostApplyLayerNorm(
+            output->DATA_PTR<scalar_t_out>(),
+            mean->DATA_PTR<float>(),
+            invvar->DATA_PTR<float>(),
+            input->DATA_PTR<scalar_t_in>(),
+            n1,n2,
+            epsilon,
+            gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
+            beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL);
+    )
 }


 template<typename T, typename U, typename V>
 void HostLayerNormGradient(
     const V* dout,
@@ -742,10 +738,12 @@ void HostLayerNormGradient(
     const int part_size = 16;
     const dim3 threads2(32,4,1);
     const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
-    const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
+    const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y *
+        (threads2.x + 1);
     const int nshared2_b = threads2.x * threads2.y * sizeof(U);
     const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
-    at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(input->scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input->scalar_type()));
+    at::Tensor part_grad_gamma = at::empty(
+        {part_size,n2}, input->options().dtype(at::ScalarType::Float));
     at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
     cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
         dout,
@@ -770,7 +768,8 @@ void HostLayerNormGradient(
     }

     // compute grad_input
-    const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
+    const uint64_t maxGridY =
+        at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
     const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
     const dim3 threads1(32,4,1);
     int nshared =
@@ -788,6 +787,7 @@ void HostLayerNormGradient(
         grad_input);
 }


 void cuda_layer_norm_gradient(
     at::Tensor* dout,
     at::Tensor* mean,
@@ -808,22 +808,22 @@ void cuda_layer_norm_gradient(
     at::Tensor* grad_beta)
 {
     using namespace at;
-    DISPATCH_FLOAT_AND_HALF(input->scalar_type(), 0, "cuComputeGradInput",
-        using accscalar_t = at::acc_type<scalar_t_0, true>;
-        using output_t = at::Half;
-        HostLayerNormGradient(
-            dout->DATA_PTR<output_t>(),
-            mean->DATA_PTR<accscalar_t>(),
-            invvar->DATA_PTR<accscalar_t>(),
-            input,
-            n1,n2,
-            // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
-            // if gamma Tensor is NULL on input.
-            gamma != NULL ? gamma->DATA_PTR<output_t>() : NULL,
-            gamma != NULL ? beta->DATA_PTR<output_t>() : NULL,
-            epsilon,
-            grad_input->DATA_PTR<scalar_t_0>(),
-            gamma != NULL ? grad_gamma->DATA_PTR<output_t>() : NULL,
-            gamma != NULL ? grad_beta->DATA_PTR<output_t>() : NULL);
-    )
+    DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
+        input->scalar_type(), gamma->scalar_type(),
+        "cuda_layer_norm_gradient_kernel",
+        HostLayerNormGradient(
+            dout->DATA_PTR<scalar_t_out>(),
+            mean->DATA_PTR<float>(),
+            invvar->DATA_PTR<float>(),
+            input,
+            n1,n2,
+            // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
+            // if gamma Tensor is NULL on input.
+            gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
+            gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL,
+            epsilon,
+            grad_input->DATA_PTR<scalar_t_in>(),
+            gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
+            gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);
+    )
 }
@@ -37,8 +37,9 @@ torch::Tensor fwd(
     torch::Tensor const& mask,
     float scale_factor) {
   AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
-  AT_ASSERTM(input.scalar_type() == at::ScalarType::Half,
-             "Only HALF is supported");
+  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
+             (input.scalar_type() == at::ScalarType::BFloat16),
+             "Only fp16 and bf16 are supported");
   AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");

   return fwd_cuda(input, mask, scale_factor);
@@ -52,10 +53,12 @@ torch::Tensor bwd(
   AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor");
   AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor");

-  AT_ASSERTM(output_grads.scalar_type() == at::ScalarType::Half,
-             "Only HALF is supported");
-  AT_ASSERTM(softmax_results.scalar_type() == at::ScalarType::Half,
-             "Only HALF is supported");
+  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
+             (output_grads.scalar_type() == at::ScalarType::BFloat16),
+             "Only fp16 and bf16 are supported");
+  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
+             (softmax_results.scalar_type() == at::ScalarType::BFloat16),
+             "Only fp16 and bf16 are supported");

   return bwd_cuda(output_grads, softmax_results, scale_factor);
 }
......
@@ -30,10 +30,16 @@ template <typename Datatype, int ELEMENTS_PER_LDG>
 __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

 template <>
-__device__ __inline__ void copy_vector<__half, 1>(__half *dst, const __half *src) { *dst = *src; }
+__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }

 template <>
-__device__ __inline__ void copy_vector<__half, 4>(__half *dst, const __half *src) { *((float2*) dst) = *((float2*) src); }
+__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
+
+template <>
+__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }
+
+template <>
+__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }

 template <>
 __device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
......
@@ -22,6 +22,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 #include "scaled_masked_softmax.h"
+#include "type_shim.h"

 namespace multihead_attn {
 namespace fused_softmax {
@@ -55,16 +56,20 @@ torch::Tensor fwd_cuda(
   void* mask_ptr = static_cast<void*>(mask.data_ptr());
   void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

-  dispatch_scaled_masked_softmax_forward<half, half, float>(
-      reinterpret_cast<half*>(softmax_results_ptr),
-      reinterpret_cast<const half*>(input_ptr),
-      reinterpret_cast<const uint8_t*>(mask_ptr),
-      scale_factor,
-      query_seq_len,
-      key_seq_len,
-      batches,
-      attn_heads,
-      pad_batches);
+  DISPATCH_HALF_AND_BFLOAT(
+      input.scalar_type(),
+      "dispatch_scaled_masked_softmax_forward",
+      dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
+          reinterpret_cast<scalar_t*>(softmax_results_ptr),
+          reinterpret_cast<const scalar_t*>(input_ptr),
+          reinterpret_cast<const uint8_t*>(mask_ptr),
+          scale_factor,
+          query_seq_len,
+          key_seq_len,
+          batches,
+          attn_heads,
+          pad_batches);
+  );

   return softmax_results;
 }
@@ -85,15 +90,19 @@ torch::Tensor bwd_cuda(
   void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

   //Softmax Grad
-  dispatch_scaled_masked_softmax_backward<half, half, float>(
-      reinterpret_cast<half*>(output_grads_ptr),
-      reinterpret_cast<half*>(output_grads_ptr),
-      reinterpret_cast<half const*>(softmax_results.data_ptr()),
-      scale_factor,
-      query_seq_len,
-      key_seq_len,
-      batches,
-      attn_heads);
+  DISPATCH_HALF_AND_BFLOAT(
+      output_grads_.scalar_type(),
+      "dispatch_scaled_masked_softmax_backward",
+      dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
+          reinterpret_cast<scalar_t*>(output_grads_ptr),
+          reinterpret_cast<scalar_t*>(output_grads_ptr),
+          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
+          scale_factor,
+          query_seq_len,
+          key_seq_len,
+          batches,
+          attn_heads);
+  );

   //backward pass is completely in-place
   return output_grads;
......
@@ -33,8 +33,9 @@ torch::Tensor bwd_cuda(
 torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
   AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
-  AT_ASSERTM(input.scalar_type() == at::ScalarType::Half,
-             "Only HALF is supported");
+  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
+             (input.scalar_type() == at::ScalarType::BFloat16),
+             "Only fp16 and bf16 are supported");

   return fwd_cuda(input, scale_factor);
 }
@@ -47,10 +48,12 @@ torch::Tensor bwd(
   AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
   AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");

-  AT_ASSERTM(output_grads.scalar_type() == at::ScalarType::Half,
-             "Only HALF is supported");
-  AT_ASSERTM(softmax_results.scalar_type() == at::ScalarType::Half,
-             "Only HALF is supported");
+  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
+             (output_grads.scalar_type() == at::ScalarType::BFloat16),
+             "Only fp16 and bf16 are supported");
+  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
+             (softmax_results.scalar_type() == at::ScalarType::BFloat16),
+             "Only fp16 and bf16 are supported");

   return bwd_cuda(output_grads, softmax_results, scale_factor);
 }
@@ -61,7 +64,7 @@ torch::Tensor bwd(
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("forward",
         &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
         "Self Multihead Attention scaled, time masked softmax -- Forward.");
   m.def("backward",
         &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
......
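Because the bindings above now accept both fp16 and bf16, a quick smoke test might look like the sketch below. The module name scaled_upper_triang_masked_softmax_cuda is an assumption about how Megatron registers this extension, and the (attn_batches, seq_len, seq_len) input layout and dtype-preserving behavior are taken from the fwd_cuda changes in this commit.

# Hedged sketch: the extension/module name and exact constraints are assumptions.
import importlib
import torch

scaled_softmax = importlib.import_module("scaled_upper_triang_masked_softmax_cuda")

attn_batches, seq_len = 16, 128
scale = 1.0

for dtype in (torch.float16, torch.bfloat16):
    inp = torch.randn(attn_batches, seq_len, seq_len,
                      device="cuda", dtype=dtype)
    probs = scaled_softmax.forward(inp, scale)   # bound as m.def("forward", ...)
    assert probs.dtype == dtype                  # kernel keeps the input dtype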
@@ -21,7 +21,6 @@
 #include <cfloat>
 #include <limits>
 #include <stdint.h>
-#include <cuda_fp16.h>
 #include <c10/macros/Macros.h>

 namespace {
@@ -30,10 +29,16 @@ template <typename Datatype, int ELEMENTS_PER_LDG>
 __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

 template <>
-__device__ __inline__ void copy_vector<__half, 1>(__half *dst, const __half *src) { *dst = *src; }
+__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }
+
+template <>
+__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
+
+template <>
+__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }

 template <>
-__device__ __inline__ void copy_vector<__half, 4>(__half *dst, const __half *src) { *((float2*) dst) = *((float2*) src); }
+__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }

 template <>
 __device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
@@ -45,10 +50,16 @@ template <typename Datatype, int ELEMENTS_PER_LDG>
 __device__ __inline__ void copy_zero_vector(Datatype *dst);

 template <>
-__device__ __inline__ void copy_zero_vector<__half, 1>(__half *dst) { *dst = 0.0; }
+__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(c10::BFloat16 *dst) { *dst = 0.0; }
+
+template <>
+__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
+
+template <>
+__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) { *dst = 0.0; }

 template <>
-__device__ __inline__ void copy_zero_vector<__half, 4>(__half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
+__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }

 int log2_ceil(int value) {
......
@@ -22,6 +22,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 #include "scaled_upper_triang_masked_softmax.h"
+#include "type_shim.h"

 namespace multihead_attn {
 namespace fused_softmax {
@@ -45,15 +46,20 @@ torch::Tensor fwd_cuda(
   void* input_ptr = static_cast<void*>(input.data_ptr());
   void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

-  dispatch_scaled_upper_triang_masked_softmax_forward<half, half, float>(
-      reinterpret_cast<half*>(softmax_results_ptr),
-      reinterpret_cast<const half*>(input_ptr),
-      scale_factor,
-      seq_len,
-      seq_len,
-      attn_batches);
+  DISPATCH_HALF_AND_BFLOAT(
+      input.scalar_type(),
+      "dispatch_scaled_upper_triang_masked_softmax_forward",
+      dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
+          reinterpret_cast<scalar_t*>(softmax_results_ptr),
+          reinterpret_cast<const scalar_t*>(input_ptr),
+          scale_factor,
+          seq_len,
+          seq_len,
+          attn_batches);
+  );

   return softmax_results;
 }

 torch::Tensor bwd_cuda(
     torch::Tensor const& output_grads_,
@@ -71,14 +77,18 @@ torch::Tensor bwd_cuda(
   void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

   //Softmax Grad
-  dispatch_scaled_upper_triang_masked_softmax_backward<half, half, float>(
-      reinterpret_cast<half*>(output_grads_ptr),
-      reinterpret_cast<half*>(output_grads_ptr),
-      reinterpret_cast<half const*>(softmax_results.data_ptr()),
-      scale_factor,
-      seq_len,
-      seq_len,
-      attn_batches);
+  DISPATCH_HALF_AND_BFLOAT(
+      output_grads_.scalar_type(),
+      "dispatch_scaled_upper_triang_masked_softmax_backward",
+      dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
+          reinterpret_cast<scalar_t*>(output_grads_ptr),
+          reinterpret_cast<scalar_t*>(output_grads_ptr),
+          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
+          scale_factor,
+          seq_len,
+          seq_len,
+          attn_batches);
+  );

   //backward pass is completely in-place
   return output_grads;
......
@@ -14,214 +14,78 @@
  * limitations under the License.
  */

+/*This code is copied fron NVIDIA apex:
+ *     https://github.com/NVIDIA/apex
+ *     with minor changes. */

 #include <ATen/ATen.h>
 #include "compat.h"

-// Forward/backward compatiblity hack around
-// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
-// pending more future-proof guidance from upstream.
-// struct TypeShim
-// {
-//   const at::Type& payload;
-//   TypeShim(const at::Type& type) : payload(type) {}
-//   // Enable trivial conversion to a const at::Type& for pre-3aeb78
-//   operator const at::Type&(){ return payload; };
-//   // Enable dispatch switch statements to take *this directly for post-3aeb78
-//   //operator at::ScalarType(){ return payload.; };
-// };

-#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...)                     \
-  switch(TYPE)                                                              \
-  {                                                                         \
-    case at::ScalarType::Float:                                             \
-    {                                                                       \
-      using scalar_t_##LEVEL = float;                                       \
-      __VA_ARGS__;                                                          \
-      break;                                                                \
-    }                                                                       \
-    case at::ScalarType::Half:                                              \
-    {                                                                       \
-      using scalar_t_##LEVEL = at::Half;                                    \
-      __VA_ARGS__;                                                          \
-      break;                                                                \
-    }                                                                       \
-    default:                                                                \
-      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");       \
-  }
+#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...)                           \
+  switch(TYPE)                                                              \
+  {                                                                         \
+    case at::ScalarType::Half:                                              \
+    {                                                                       \
+      using scalar_t = at::Half;                                            \
+      __VA_ARGS__;                                                          \
+      break;                                                                \
+    }                                                                       \
+    case at::ScalarType::BFloat16:                                          \
+    {                                                                       \
+      using scalar_t = at::BFloat16;                                        \
+      __VA_ARGS__;                                                          \
+      break;                                                                \
+    }                                                                       \
+    default:                                                                \
+      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");       \
+  }
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch(TYPEIN) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_in = float; \
switch(TYPEOUT) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Byte: \
{ \
using scalar_t_##LEVEL = uint8_t; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes
(T *x,
T val,
int lanes=1,
bool share_result=false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y*blockDim.x;
int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32.
if(blockSize >= 64)
{
x[tid] = val;
__syncthreads();
}
#pragma unroll
for(int i = (blockSize >> 1); i >= 64; i >>= 1)
{
if(tid < i)
x[tid] = x[tid] + x[tid+i];
__syncthreads();
}
T final;
if(tid < 32)
{
if(blockSize >= 64)
final = x[tid] + x[tid+32];
else
final = val;
// __SYNCWARP();
#pragma unroll
for(int i = 16; i >= lanes; i >>= 1)
final = final + __shfl_down_sync(0xffffffff, final, i);
}
if(share_result)
{
if(tid < lanes)
x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes_max_op
(T *x,
T val,
int lanes=1,
bool share_result=false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y*blockDim.x;
int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32.
if(blockSize >= 64)
{
x[tid] = val;
__syncthreads();
}
#pragma unroll
for(int i = (blockSize >> 1); i >= 64; i >>= 1)
{
if(tid < i)
x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid+i]));
__syncthreads();
}
T final;
if(tid < 32)
{
if(blockSize >= 64)
final = fmaxf(fabsf(x[tid]), fabsf(x[tid+32]));
else
final = val;
// __SYNCWARP();
#pragma unroll
for(int i = 16; i >= lanes; i >>= 1)
final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
}
if(share_result)
{
if(tid < lanes)
x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
@@ -13,23 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-_LAYER_NORM = None
-
-def import_layernorm(fp32_residual_connection, bf16):
-    global _LAYER_NORM
-    if not _LAYER_NORM:
-        if bf16:
-            from torch.nn import LayerNorm
-        elif fp32_residual_connection:
-            from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
-        else:
-            from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
-        _LAYER_NORM = LayerNorm
-    return _LAYER_NORM
+from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm

 from .distributed import *
 from .bert_model import (BertModel,
......
@@ -22,7 +22,7 @@ from megatron import mpu
 from megatron.model.enums import AttnMaskType
 from megatron.model.language_model import parallel_lm_logits
 from megatron.model.language_model import get_language_model
-from megatron.model import import_layernorm
+from megatron.model import LayerNorm
 from megatron.model.utils import openai_gelu, erf_gelu
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
@@ -78,7 +78,6 @@ class BertLMHead(MegatronModule):
         self.parallel_output = parallel_output

         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
-        LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)
         self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
         self.gelu = torch.nn.functional.gelu
         if args.openai_gelu:
......
@@ -15,29 +15,23 @@
 """This code is copied fron NVIDIA apex:
       https://github.com/NVIDIA/apex
-   with minor changes. """
+   with some changes. """

-import math
-import torch
 import numbers
+import torch
 from torch.nn.parameter import Parameter
 from torch.nn import init
-from torch.nn import functional as F
 import importlib

-global fused_layer_norm_cuda
-fused_layer_norm_cuda = None
 global fused_mix_prec_layer_norm_cuda
 fused_mix_prec_layer_norm_cuda = None


 class FusedLayerNormAffineFunction(torch.autograd.Function):

     @staticmethod
     def forward(ctx, input, weight, bias, normalized_shape, eps):
-        global fused_mix_prec_layer_norm_cuda
-        if fused_mix_prec_layer_norm_cuda is None:
-            fused_mix_prec_layer_norm_cuda = importlib.import_module("fused_mix_prec_layer_norm_cuda")
         ctx.normalized_shape = normalized_shape
         ctx.eps = eps
         input_ = input.contiguous()
@@ -46,134 +40,51 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
         output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
             input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
         ctx.save_for_backward(input_, weight_, bias_, mean, invvar)

         return output

     @staticmethod
     def backward(ctx, grad_output):
         input_, weight_, bias_, mean, invvar = ctx.saved_tensors
         grad_input = grad_weight = grad_bias = None
-        grad_input, grad_weight, grad_bias = fused_mix_prec_layer_norm_cuda.backward_affine(
-            grad_output.contiguous(), mean, invvar,
-            input_, ctx.normalized_shape,
-            weight_, bias_, ctx.eps)
+        grad_input, grad_weight, grad_bias \
+            = fused_mix_prec_layer_norm_cuda.backward_affine(
+                grad_output.contiguous(), mean, invvar,
+                input_, ctx.normalized_shape,
+                weight_, bias_, ctx.eps)
+
         return grad_input, grad_weight, grad_bias, None, None


-class FusedLayerNormFunction(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, input, normalized_shape, eps):
-        global fused_layer_norm_cuda
-        if fused_layer_norm_cuda is None:
-            fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
-        ctx.normalized_shape = normalized_shape
-        ctx.eps = eps
-        input_ = input.contiguous()
-        output, mean, invvar = fused_layer_norm_cuda.forward(
-            input_, ctx.normalized_shape, ctx.eps)
-        ctx.save_for_backward(input_, mean, invvar)
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        input_, mean, invvar = ctx.saved_tensors
-        grad_input = None
-        grad_input = fused_layer_norm_cuda.backward(
-            grad_output.contiguous(), mean, invvar,
-            input_, ctx.normalized_shape,
-            ctx.eps)
-        return grad_input, None, None
-
-
-def fused_layer_norm_affine(input, normalized_shape, weight, bias, eps=1e-6):
-    return FusedLayerNormAffineFunction.apply(input, weight, bias, normalized_shape, eps)
-
-
-def fused_layer_norm(input, normalized_shape, eps=1e-6):
-    return FusedLayerNormFunction.apply(input, normalized_shape, eps)
-
-
 class MixedFusedLayerNorm(torch.nn.Module):
-    r"""Applies Layer Normalization over a mini-batch of inputs as described in
-    the paper `Layer Normalization`_ .
-    Currently only runs on cuda() tensors.
-    .. math::
-        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
-    The mean and standard-deviation are calculated separately over the last
-    certain number dimensions which have to be of the shape specified by
-    :attr:`normalized_shape`.
-    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
-    :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
-    .. note::
-        Unlike Batch Normalization and Instance Normalization, which applies
-        scalar scale and bias for each entire channel/plane with the
-        :attr:`affine` option, Layer Normalization applies per-element scale and
-        bias with :attr:`elementwise_affine`.
-    This layer uses statistics computed from input data in both training and
-    evaluation modes.
-    Args:
-        normalized_shape (int or list or torch.Size): input shape from an expected input
-            of size
-            .. math::
-                [* \times \text{normalized}\_\text{shape}[0] \times \text{normalized}\_\text{shape}[1]
-                    \times \ldots \times \text{normalized}\_\text{shape}[-1]]
-            If a single integer is used, it is treated as a singleton list, and this module will
-            normalize over the last dimension which is expected to be of that specific size.
-        eps: a value added to the denominator for numerical stability. Default: 1e-5
-        elementwise_affine: a boolean value that when set to ``True``, this module
-            has learnable per-element affine parameters initialized to ones (for weights)
-            and zeros (for biases). Default: ``True``.
-    Shape:
-        - Input: :math:`(N, *)`
-        - Output: :math:`(N, *)` (same shape as input)
-    Examples::
-        >>> input = torch.randn(20, 5, 10, 10)
-        >>> # With Learnable Parameters
-        >>> m = apex.normalization.FusedLayerNorm(input.size()[1:])
-        >>> # Without Learnable Parameters
-        >>> m = apex.normalization.FusedLayerNorm(input.size()[1:], elementwise_affine=False)
-        >>> # Normalize over last two dimensions
-        >>> m = apex.normalization.FusedLayerNorm([10, 10])
-        >>> # Normalize over last dimension of size 10
-        >>> m = apex.normalization.FusedLayerNorm(10)
-        >>> # Activating the module
-        >>> output = m(input)
-    .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
-    """
-    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
+
+    def __init__(self, normalized_shape, eps=1e-5):
         super(MixedFusedLayerNorm, self).__init__()
-        global fused_layer_norm_cuda
-        fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
+
         global fused_mix_prec_layer_norm_cuda
-        fused_mix_prec_layer_norm_cuda = importlib.import_module("fused_mix_prec_layer_norm_cuda")
+        fused_mix_prec_layer_norm_cuda = importlib.import_module(
+            "fused_mix_prec_layer_norm_cuda")

         if isinstance(normalized_shape, numbers.Integral):
             normalized_shape = (normalized_shape,)
         self.normalized_shape = torch.Size(normalized_shape)
         self.eps = eps
-        self.elementwise_affine = elementwise_affine
-        if self.elementwise_affine:
-            self.weight = Parameter(torch.Tensor(*normalized_shape))
-            self.bias = Parameter(torch.Tensor(*normalized_shape))
-        else:
-            self.register_parameter('weight', None)
-            self.register_parameter('bias', None)
+        self.weight = Parameter(torch.Tensor(*normalized_shape))
+        self.bias = Parameter(torch.Tensor(*normalized_shape))
         self.reset_parameters()

     def reset_parameters(self):
-        if self.elementwise_affine:
-            init.ones_(self.weight)
-            init.zeros_(self.bias)
+        init.ones_(self.weight)
+        init.zeros_(self.bias)

     def forward(self, input):
-        if not input.is_cuda:
-            return F.layer_norm(
-                input, self.normalized_shape, self.weight, self.bias, self.eps)
-        if self.elementwise_affine:
-            return FusedLayerNormAffineFunction.apply(
-                input, self.weight, self.bias, self.normalized_shape,self.eps)
-        else:
-            return FusedLayerNormFunction.apply(input, self.normalized_shape, self.eps)
-
-    def extra_repr(self):
-        return '{normalized_shape}, eps={eps}, ' \
-            'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
+        return FusedLayerNormAffineFunction.apply(
+            input, self.weight, self.bias, self.normalized_shape, self.eps)
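A minimal usage sketch of the simplified module follows. It assumes the fused_mix_prec_layer_norm_cuda extension has already been JIT-built (e.g. by megatron.fused_kernels.load), and the bf16 behavior shown for the second call is an assumption inferred from the kernel changes in this commit (output dtype follows the weight/bias dtype, statistics in fp32).

# Hedged sketch of using MixedFusedLayerNorm outside the training loop.
import torch
from megatron.model.fused_layer_norm import MixedFusedLayerNorm

hidden_size = 1024
ln = MixedFusedLayerNorm(hidden_size).cuda()

# fp32 activations (e.g. an fp32 residual stream) are normalized with fp32
# statistics; while the parameters are fp32 the output stays fp32.
x = torch.randn(32, hidden_size, device="cuda", dtype=torch.float32)
y = ln(x)
print(y.dtype)          # torch.float32

# After casting the parameters to bf16 (as Float16Module would), the same
# fp32 input should produce a bf16 output, matching the float-in / bf16-out
# dispatch added to the CUDA kernel (assumed behavior of this path).
ln_bf16 = ln.bfloat16()
y_bf16 = ln_bf16(x)
print(y_bf16.dtype)     # torch.bfloat16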
@@ -96,6 +96,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
     def __init__(
         self,
         input_in_fp16,
+        input_in_bf16,
         attn_mask_type,
         scaled_masked_softmax_fusion,
         mask_func,
@@ -104,6 +105,10 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
     ):
         super(FusedScaleMaskSoftmax, self).__init__()
         self.input_in_fp16 = input_in_fp16
+        self.input_in_bf16 = input_in_bf16
+        assert not (self.input_in_fp16 and self.input_in_bf16),\
+            'both fp16 and bf16 flags cannot be active at the same time.'
+        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
         self.attn_mask_type = attn_mask_type
         self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
         self.mask_func = mask_func
@@ -128,8 +133,8 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
             query_seq_len % 4 == 0 and attn_batch_size % 4 == 0

         # invoke custom kernel
-        if self.input_in_fp16 and mask is not None and \
+        if self.input_in_float16 and mask is not None and \
             custom_kernel_constraint and self.scaled_masked_softmax_fusion:
             scale = self.scale if self.scale is not None else 1.0
             if self.attn_mask_type == AttnMaskType.causal:
@@ -142,7 +147,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
                 assert self.attn_mask_type == AttnMaskType.padding
                 probs = ScaledMaskedSoftmax.apply(input, mask, scale)
         else:
-            if self.input_in_fp16 and self.softmax_in_fp32:
+            if self.input_in_float16 and self.softmax_in_fp32:
                 input = input.float()

             if self.scale is not None:
@@ -150,7 +155,10 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
             mask_output = self.mask_func(input, mask) if mask is not None else input
             probs = torch.nn.Softmax(dim=-1)(mask_output)

-            if self.input_in_fp16 and self.softmax_in_fp32:
-                probs = probs.half()
+            if self.input_in_float16 and self.softmax_in_fp32:
+                if self.input_in_fp16:
+                    probs = probs.half()
+                else:
+                    probs = probs.bfloat16()

         return probs
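To illustrate the new constructor flag, a hedged sketch of how a caller now selects the bf16 path; the trailing softmax_in_fp32 and scale arguments are assumptions about constructor parameters not shown in this hunk.

# Hedged sketch: constructing the softmax wrapper for a bf16 model.
from megatron.model.enums import AttnMaskType
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.utils import attention_mask_func

scale_mask_softmax = FusedScaleMaskSoftmax(
    input_in_fp16=False,
    input_in_bf16=True,                  # exactly one of the two may be True
    attn_mask_type=AttnMaskType.causal,
    scaled_masked_softmax_fusion=True,
    mask_func=attention_mask_func,
    softmax_in_fp32=False,               # assumed remaining parameters
    scale=None)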
@@ -22,7 +22,7 @@ from megatron import get_args
 from megatron import mpu
 from .module import MegatronModule
 from megatron.model.enums import AttnMaskType, LayerType, AttnType
-from megatron.model import import_layernorm
+from megatron.model import LayerNorm
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
@@ -116,6 +116,7 @@ class ParallelAttention(MegatronModule):
         super(ParallelAttention, self).__init__()
         args = get_args()
         self.fp16 = args.fp16
+        self.bf16 = args.bf16

         self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
         self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
@@ -164,7 +165,7 @@ class ParallelAttention(MegatronModule):
             self.norm_factor *= coeff

         self.scale_mask_softmax = FusedScaleMaskSoftmax(
-            self.fp16,
+            self.fp16, self.bf16,
             self.attn_mask_type,
             args.masked_softmax_fusion,
             attention_mask_func,
@@ -401,7 +402,6 @@ class ParallelTransformerLayer(MegatronModule):
         self.fp32_residual_connection = args.fp32_residual_connection

         # Layernorm on the input data.
-        LayerNorm = import_layernorm(self.fp32_residual_connection, self.bf16)
         self.input_layernorm = LayerNorm(
             args.hidden_size,
             eps=args.layernorm_epsilon)
@@ -443,8 +443,6 @@ class ParallelTransformerLayer(MegatronModule):
         # Layer norm at the beginning of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)
-        if self.bf16 and self.fp32_residual_connection:
-            layernorm_output = layernorm_output.bfloat16()

         # Self attention.
         attention_output, attention_bias = \
             self.self_attention(layernorm_output,
@@ -483,8 +481,6 @@ class ParallelTransformerLayer(MegatronModule):
         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)
-        if self.bf16 and self.fp32_residual_connection:
-            layernorm_output = layernorm_output.bfloat16()

         if self.layer_type == LayerType.decoder:
             attention_output, attention_bias = \
@@ -507,8 +503,6 @@ class ParallelTransformerLayer(MegatronModule):
             # Layer norm post the decoder attention
             layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
-            if self.bf16 and self.fp32_residual_connection:
-                layernorm_output = layernorm_output.bfloat16()

         # MLP.
         mlp_output, mlp_bias = self.mlp(layernorm_output)
@@ -588,8 +582,6 @@ class ParallelTransformer(MegatronModule):
         if mpu.is_pipeline_last_stage():
             # Final layer norm before output.
-            LayerNorm = import_layernorm(self.fp32_residual_connection,
-                                         self.bf16)
             self.final_layernorm = LayerNorm(
                 args.hidden_size,
                 eps=args.layernorm_epsilon)
@@ -676,8 +668,6 @@ class ParallelTransformer(MegatronModule):
             # Reverting data format change [s b h] --> [b s h].
             hidden_states = hidden_states.transpose(0, 1).contiguous()
             output = self.final_layernorm(hidden_states)
-            if self.bf16 and self.fp32_residual_connection:
-                output = output.bfloat16()
         else:
             output = hidden_states
         if get_key_value:
......
@@ -17,7 +17,7 @@ from apex.optimizers import FusedAdam as Adam
 from apex.optimizers import FusedSGD as SGD

 from megatron import get_args
-from megatron.model import import_layernorm
+from megatron.model import LayerNorm

 from .grad_scaler import ConstantGradScaler, DynamicGradScaler
 from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
@@ -27,8 +27,6 @@ def _get_params_for_weight_decay_optimization(modules):
     """Divide params into with-weight-decay and without-weight-decay groups.
     Layernorms and baises will have no weight decay but the rest will.
     """
-    args = get_args()
-    LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)

     weight_decay_params = {'params': []}
     no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
......
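For context, the grouping that follows this hunk typically looks like the hedged sketch below: LayerNorm parameters and all biases go into the no-weight-decay group. The loop body is an assumption about the unchanged remainder of _get_params_for_weight_decay_optimization, shown only to make the role of the module-level LayerNorm import concrete.

# Hedged sketch of the rest of the grouping logic (not shown in the diff).
from megatron.model import LayerNorm

def _get_params_for_weight_decay_optimization(modules):
    """Divide params into with-weight-decay and without-weight-decay groups."""
    weight_decay_params = {'params': []}
    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
    for module in modules:
        for module_ in module.modules():
            if isinstance(module_, LayerNorm):
                # All LayerNorm parameters (weight and bias) skip weight decay.
                no_weight_decay_params['params'].extend(
                    p for p in module_._parameters.values() if p is not None)
            else:
                weight_decay_params['params'].extend(
                    p for n, p in module_._parameters.items()
                    if p is not None and n != 'bias')
                no_weight_decay_params['params'].extend(
                    p for n, p in module_._parameters.items()
                    if p is not None and n == 'bias')
    return weight_decay_params, no_weight_decay_params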
@@ -224,16 +224,6 @@ def get_model(model_provider_func):
     # Fp16 conversion.
     if args.fp16 or args.bf16:
         model = [Float16Module(model_module, args) for model_module in model]
-        # For now, the layer norm does not support input float32 and outut bf16.
-        # For this, we move layernorm parameters to fp32 and cast output of the
-        # layernorm operation back to bf16.
-        if args.bf16 and args.fp32_residual_connection:
-            from megatron.model import import_layernorm
-            LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)
-            for model_ in model:
-                for module_ in model_.modules():
-                    if isinstance(module_, LayerNorm):
-                        module_.float()

     if args.DDP_impl == 'torch':
         i = torch.cuda.current_device()
......