Commit b5726555 authored by Sangkug Lym

support fp32 training and fix embedding update

parent 83b1e42f
@@ -541,7 +541,7 @@ def _add_training_args(parser):
                        'size is supported.')
     group.add_argument('--no-gradient-accumulation-fusion',
                        action='store_false',
-                       help='Disable fuisng gradient accumulation to weight '
+                       help='Disable fusing gradient accumulation to weight '
                        'gradient computation of linear layers',
                        dest='gradient_accumulation_fusion')
     return parser
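As a quick reference for the flag semantics above (a minimal standalone argparse sketch, not the full Megatron argument builder): `store_false` with an explicit `dest` means gradient accumulation fusion stays enabled by default, and passing the flag turns it off.

```python
import argparse

# Minimal sketch: a store_false flag with an explicit dest leaves the feature
# enabled by default and disables it when the flag is passed.
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='training')
group.add_argument('--no-gradient-accumulation-fusion',
                   action='store_false',
                   help='Disable fusing gradient accumulation to weight '
                        'gradient computation of linear layers',
                   dest='gradient_accumulation_fusion')

print(parser.parse_args([]).gradient_accumulation_fusion)
# True  -> fusion enabled by default
args = parser.parse_args(['--no-gradient-accumulation-fusion'])
print(args.gradient_accumulation_fusion)
# False -> fusion disabled
```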
@@ -31,7 +31,7 @@ void wgrad_gemm_accum_fp32(const at::Tensor input, const at::Tensor d_output, at
     int in_dim = input_2d.size(1);
     int out_dim = d_weight.size(0);
 
-    DISPATCH_HALF_AND_BFLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp32",
+    DISPATCH_HALF_BFLOAT_AND_FLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp32",
         int result = wgrad_gemm_accum_fp32_cuda<scalar_t>(
             input_2d.data_ptr<scalar_t>(),
             d_output_2d.data_ptr<scalar_t>(),
@@ -87,6 +87,44 @@ cublasStatus_t gemmex_wrapper(
       CUBLAS_GEMM_DEFAULT_TENSOR_OP);
 }
 
+// FP32 Tensor core wrapper around cublas GEMMEx
+cublasStatus_t gemmex_wrapper(
+    cublasHandle_t handle,
+    cublasOperation_t transa,
+    cublasOperation_t transb,
+    int m,
+    int n,
+    int k,
+    const float* alpha,
+    float* A,
+    int lda,
+    float* B,
+    int ldb,
+    const float* beta,
+    float* C,
+    int ldc) {
+  return cublasGemmEx(
+      handle,
+      transa,
+      transb,
+      m,
+      n,
+      k,
+      alpha,
+      A,
+      CUDA_R_32F,
+      lda,
+      B,
+      CUDA_R_32F,
+      ldb,
+      beta,
+      C,
+      CUDA_R_32F,
+      ldc,
+      CUDA_R_32F,
+      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+}
+
 template <typename T>
 int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim) {
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();

@@ -116,3 +154,4 @@ int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_di
 template int wgrad_gemm_accum_fp32_cuda<at::Half>(at::Half *input, at::Half *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
 template int wgrad_gemm_accum_fp32_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
+template int wgrad_gemm_accum_fp32_cuda<float>(float *input, float *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
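The new `float` overload of `gemmex_wrapper` and the added template instantiation extend the fp32-accumulating weight-gradient GEMM to fp32 inputs in addition to fp16/bf16. A pure-PyTorch sketch of the arithmetic this path performs (a reference for the semantics under my reading of the diff, not the CUDA kernel itself):

```python
import torch

# Reference sketch: the fused wgrad GEMM accumulates grad_output^T @ input
# into the persistent fp32 main_grad buffer. With this commit the same path
# also accepts fp32 activations, not only fp16/bf16.
def wgrad_accum_fp32_reference(input_2d, d_output_2d, main_grad):
    # input_2d: [tokens, in_dim], d_output_2d: [tokens, out_dim],
    # main_grad: [out_dim, in_dim] in float32.
    main_grad.add_(d_output_2d.float().t().matmul(input_2d.float()))

tokens, in_dim, out_dim = 8, 16, 32
for dtype in (torch.half, torch.bfloat16, torch.float32):
    inp = torch.randn(tokens, in_dim, dtype=dtype)
    d_out = torch.randn(tokens, out_dim, dtype=dtype)
    main_grad = torch.zeros(out_dim, in_dim, dtype=torch.float32)
    wgrad_accum_fp32_reference(inp, d_out, main_grad)
```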
@@ -39,6 +39,32 @@
 }
 
+#define DISPATCH_HALF_BFLOAT_AND_FLOAT(TYPE, NAME, ...) \
+    switch(TYPE) \
+    { \
+        case at::ScalarType::Half: \
+        { \
+            using scalar_t = at::Half; \
+            __VA_ARGS__; \
+            break; \
+        } \
+        case at::ScalarType::BFloat16: \
+        { \
+            using scalar_t = at::BFloat16; \
+            __VA_ARGS__; \
+            break; \
+        } \
+        case at::ScalarType::Float: \
+        { \
+            using scalar_t = float; \
+            __VA_ARGS__; \
+            break; \
+        } \
+        default: \
+            AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+    }
+
 #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
     switch(TYPEIN) \
@@ -168,21 +168,14 @@ class DistributedDataParallel(DistributedDataParallelBase):
         """Create the all-reduce hook for backprop."""
         # Hook used for back-prop.
         def param_hook(*unused):
-            if not self.skip_gradient_func(param):
-                # Add the gradient to the buffer.
-                if param.grad.data is not None:
-                    # The gradient function of linear layers is fused with GEMMs
-                    param.main_grad.add_(param.grad.data)
-                    # Now we can deallocate grad memory.
-                    param.grad = None
+            # Add the gradient to the buffer.
+            if param.grad is not None:
+                # The gradient function of linear layers is fused with GEMMs
+                param.main_grad.add_(param.grad.data)
+                # Now we can deallocate grad memory.
+                param.grad = None
         return param_hook
 
-    def skip_gradient_func(self, param):
-        # Skip gradient function of linear layers
-        # Gradient accumulation is fused to weight gradient computation operators
-        if getattr(param, 'fuse_gradient_accumulation', False):
-            return True
-        return False
-
     def zero_grad_buffer(self):
         """Set the grad buffer data to zero. Needs to be called at the
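With `skip_gradient_func` removed, the hook no longer consults the per-parameter `fuse_gradient_accumulation` attribute (whose assignments are deleted in the later hunks); it simply folds any populated `.grad` into the persistent fp32 `main_grad` buffer and frees it. A toy standalone sketch of that behavior, not the Megatron DDP class:

```python
import torch

# Toy sketch of the updated hook: accumulate a low-precision .grad into a
# persistent fp32 main_grad buffer, then release the .grad tensor.
param = torch.nn.Parameter(torch.randn(4, 4, dtype=torch.float16))
param.main_grad = torch.zeros(4, 4, dtype=torch.float32)

def param_hook(*unused):
    # Add the gradient to the buffer, then deallocate grad memory.
    if param.grad is not None:
        param.main_grad.add_(param.grad.data)
        param.grad = None

param.grad = torch.randn(4, 4, dtype=torch.float16)  # simulate a backward pass
param_hook()
assert param.grad is None and param.main_grad.abs().sum() > 0
```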
@@ -175,8 +175,6 @@ class VocabParallelEmbedding(torch.nn.Module):
                 device=torch.cuda.current_device(), dtype=args.params_dtype))
             _initialize_affine_weight_gpu(self.weight, init_method,
                                           partition_dim=0, stride=1)
-        setattr(self.weight, 'fuse_gradient_accumulation',
-                args.gradient_accumulation_fusion)
 
     def forward(self, input_):
         if self.tensor_model_parallel_size > 1:

@@ -241,7 +239,6 @@ class LinearWithGradAccumulationAndAsyncAllreduce(torch.autograd.Function):
             fused_dense_cuda.wgrad_gemm_accum_fp32(input, grad_output, weight.main_grad)
             grad_weight = None
         else:
-            # Matrix multiply with asynchronous all-reduce execution
             grad_weight = grad_output.t().matmul(input)
         grad_bias = grad_output.sum(dim=0) if use_bias else None
         if ctx.async_grad_allreduce:

@@ -327,8 +324,6 @@ class ColumnParallelLinear(torch.nn.Module):
             args.async_tensor_model_parallel_allreduce and
             world_size > 1)
         self.gradient_accumulation_fusion = args.gradient_accumulation_fusion
-        setattr(self.weight, 'fuse_gradient_accumulation',
-                self.gradient_accumulation_fusion)
 
     def forward(self, input_):

@@ -431,8 +426,6 @@ class RowParallelLinear(torch.nn.Module):
         else:
             self.register_parameter('bias', None)
         self.gradient_accumulation_fusion = args.gradient_accumulation_fusion
-        setattr(self.weight, 'fuse_gradient_accumulation',
-                self.gradient_accumulation_fusion)
 
     def forward(self, input_):
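Taken together with the DDP hook hunk above, the fused/unfused choice now lives entirely in the linear layer's backward: either the GEMM accumulates straight into `main_grad`, or the plain matmul result is later added by the hook. A condensed, runnable sketch of that split; `wgrad_gemm_accum_fp32` below is a pure-PyTorch stand-in for the `fused_dense_cuda` extension call, not the real kernel.

```python
import torch

# Stand-in for fused_dense_cuda.wgrad_gemm_accum_fp32 (illustrative only).
def wgrad_gemm_accum_fp32(input, grad_output, main_grad):
    main_grad.add_(grad_output.float().t().matmul(input.float()))

def weight_grad(input, grad_output, weight, gradient_accumulation_fusion):
    if gradient_accumulation_fusion:
        # Fused path: accumulate straight into the fp32 main_grad buffer.
        wgrad_gemm_accum_fp32(input, grad_output, weight.main_grad)
        return None
    # Unfused path: return the gradient; the DDP param_hook adds it later.
    return grad_output.t().matmul(input)

weight = torch.nn.Parameter(torch.randn(32, 16))
weight.main_grad = torch.zeros(32, 16, dtype=torch.float32)
inp, d_out = torch.randn(8, 16), torch.randn(8, 32)

fused = weight_grad(inp, d_out, weight, gradient_accumulation_fusion=True)
unfused = weight_grad(inp, d_out, weight, gradient_accumulation_fusion=False)
# Both routes produce the same weight gradient, just stored in different places.
assert fused is None
assert torch.allclose(weight.main_grad, unfused)
```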