Commit b5726555 authored by Sangkug Lym

support fp32 training and fix embedding update

parent 83b1e42f
......@@ -541,7 +541,7 @@ def _add_training_args(parser):
'size is supported.')
group.add_argument('--no-gradient-accumulation-fusion',
action='store_false',
help='Disable fuisng gradient accumulation to weight '
help='Disable fusing gradient accumulation to weight '
'gradient computation of linear layers',
dest='gradient_accumulation_fusion')
return parser
......
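Note: a minimal standalone sketch of the store_false flag pattern used above (plain argparse, not the full Megatron argument parser). With no explicit default, argparse leaves gradient_accumulation_fusion set to True unless the flag is passed.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--no-gradient-accumulation-fusion',
                    action='store_false',
                    help='Disable fusing gradient accumulation to weight '
                         'gradient computation of linear layers',
                    dest='gradient_accumulation_fusion')

# Fusion stays enabled when the flag is absent ...
print(parser.parse_args([]).gradient_accumulation_fusion)   # True
# ... and is turned off only when the flag is given.
print(parser.parse_args(['--no-gradient-accumulation-fusion']).gradient_accumulation_fusion)   # False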
......@@ -31,7 +31,7 @@ void wgrad_gemm_accum_fp32(const at::Tensor input, const at::Tensor d_output, at
int in_dim = input_2d.size(1);
int out_dim = d_weight.size(0);
DISPATCH_HALF_AND_BFLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp32",
DISPATCH_HALF_BFLOAT_AND_FLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp32",
int result = wgrad_gemm_accum_fp32_cuda<scalar_t>(
input_2d.data_ptr<scalar_t>(),
d_output_2d.data_ptr<scalar_t>(),
......
......@@ -87,6 +87,44 @@ cublasStatus_t gemmex_wrapper(
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
// FP32 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float* alpha,
float* A,
int lda,
float* B,
int ldb,
const float* beta,
float* C,
int ldc) {
return cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_32F,
lda,
B,
CUDA_R_32F,
ldb,
beta,
C,
CUDA_R_32F,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
template <typename T>
int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim) {
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
......@@ -116,3 +154,4 @@ int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_di
template int wgrad_gemm_accum_fp32_cuda<at::Half>(at::Half *input, at::Half *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
template int wgrad_gemm_accum_fp32_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
template int wgrad_gemm_accum_fp32_cuda<float>(float *input, float *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
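Note: with the new float overload of gemmex_wrapper and the float template instantiation above, the compiled extension should accept fp32 inputs as well. A hypothetical smoke test, assuming the extension has been built and is importable as fused_dense_cuda (the module name used in layers.py below) and that a CUDA device is available:

import torch
import fused_dense_cuda

batch, in_dim, out_dim = 8, 1024, 4096
inp       = torch.randn(batch, in_dim,  device='cuda', dtype=torch.float32)
grad_out  = torch.randn(batch, out_dim, device='cuda', dtype=torch.float32)
main_grad = torch.zeros(out_dim, in_dim, device='cuda', dtype=torch.float32)

# Accumulates grad_out^T @ inp into main_grad in fp32; before this commit the
# dispatch macro only covered fp16/bf16 inputs, so this call would have errored.
fused_dense_cuda.wgrad_gemm_accum_fp32(inp, grad_out, main_grad)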
......@@ -39,6 +39,32 @@
}
#define DISPATCH_HALF_BFLOAT_AND_FLOAT(TYPE, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Half: \
{ \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch(TYPEIN) \
......
......@@ -168,21 +168,14 @@ class DistributedDataParallel(DistributedDataParallelBase):
"""Create the all-reduce hook for backprop."""
# Hook used for back-prop.
def param_hook(*unused):
if not self.skip_gradient_func(param):
# Add the gradient to the buffer.
if param.grad.data is not None:
# The gradient function of linear layers is fused with GEMMs
param.main_grad.add_(param.grad.data)
# Now we can deallocate grad memory.
param.grad = None
# Add the gradient to the buffer.
if param.grad is not None:
# The gradient function of linear layers is fused with GEMMs
param.main_grad.add_(param.grad.data)
# Now we can deallocate grad memory.
param.grad = None
return param_hook
def skip_gradient_func(self, param):
# Skip gradient function of linear layers
# Gradient accumulation is fused to weight gradient computation operators
if getattr(param, 'fuse_gradient_accumulation', False):
return True
return False
def zero_grad_buffer(self):
"""Set the grad buffer data to zero. Needs to be called at the
......
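Note: the removed skip logic is the likely source of the embedding-update bug named in the commit title. A parameter tagged fuse_gradient_accumulation (such as the embedding weight below) could still receive a regular param.grad from non-fused backward paths, and that gradient was never folded into main_grad. A minimal standalone sketch of the simplified hook, assuming each parameter carries a pre-allocated fp32 main_grad buffer as in Megatron's DistributedDataParallel:

def make_param_hook(param):
    def param_hook(*unused):
        # Fold any gradient that autograd actually produced into the fp32
        # buffer; gradients handled entirely by the fused wgrad kernels never
        # populate param.grad, so no per-parameter skip flag is needed.
        if param.grad is not None:
            param.main_grad.add_(param.grad.data)
            # Deallocate grad memory now that it has been accumulated.
            param.grad = None
    return param_hook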
......@@ -175,8 +175,6 @@ class VocabParallelEmbedding(torch.nn.Module):
device=torch.cuda.current_device(), dtype=args.params_dtype))
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=0, stride=1)
setattr(self.weight, 'fuse_gradient_accumulation',
args.gradient_accumulation_fusion)
def forward(self, input_):
if self.tensor_model_parallel_size > 1:
......@@ -241,7 +239,6 @@ class LinearWithGradAccumulationAndAsyncAllreduce(torch.autograd.Function):
fused_dense_cuda.wgrad_gemm_accum_fp32(input, grad_output, weight.main_grad)
grad_weight = None
else:
# Matrix multiply with asynchronous all-reduce execution
grad_weight = grad_output.t().matmul(input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_allreduce:
......@@ -327,8 +324,6 @@ class ColumnParallelLinear(torch.nn.Module):
args.async_tensor_model_parallel_allreduce and
world_size > 1)
self.gradient_accumulation_fusion = args.gradient_accumulation_fusion
setattr(self.weight, 'fuse_gradient_accumulation',
self.gradient_accumulation_fusion)
def forward(self, input_):
......@@ -431,8 +426,6 @@ class RowParallelLinear(torch.nn.Module):
else:
self.register_parameter('bias', None)
self.gradient_accumulation_fusion = args.gradient_accumulation_fusion
setattr(self.weight, 'fuse_gradient_accumulation',
self.gradient_accumulation_fusion)
def forward(self, input_):
......
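Note: a simplified sketch of how the weight-gradient branch in LinearWithGradAccumulationAndAsyncAllreduce.backward ties the pieces together after this commit (the exact name of the flag saved on ctx is not shown in the hunk and is assumed here):

# Inside backward(ctx, grad_output), with input and weight recovered from ctx.
if ctx.gradient_accumulation_fusion:          # assumed ctx attribute name
    # Accumulate straight into the fp32 main_grad buffer; with the new
    # dispatch case and GEMM wrapper this path now also works for fp32
    # models, not only fp16/bf16 ones.
    fused_dense_cuda.wgrad_gemm_accum_fp32(input, grad_output, weight.main_grad)
    grad_weight = None
else:
    grad_weight = grad_output.t().matmul(input)
grad_bias = grad_output.sum(dim=0) if use_bias else None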