Commit 9c5a830f authored by Jared Casper

Merge branch 'slym/grad_accum_fusion' into 'main'

Gradient accumulation fusion

See merge request ADLR/megatron-lm!394
parents 0ed2f6ac b5726555
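For context: without the fusion, the backward pass of a linear layer materializes `grad_weight` and the local-DDP hook then adds it into the persistent fp32 `main_grad` buffer; with the fusion, the weight-gradient GEMM accumulates straight into `main_grad`. A minimal plain-PyTorch sketch of the two paths (illustrative only, not the fused CUDA kernel; the names are placeholders):

```python
import torch

def wgrad_unfused(input_2d, grad_output_2d, main_grad):
    # Two steps: materialize the weight gradient, then let the DDP hook
    # accumulate it into the persistent fp32 main_grad buffer.
    grad_weight = grad_output_2d.float().t().matmul(input_2d.float())
    main_grad.add_(grad_weight)

def wgrad_fused_reference(input_2d, grad_output_2d, main_grad):
    # One step: the fused CUDA kernel performs the same GEMM but accumulates
    # directly into main_grad (alpha = beta = 1, fp32 accumulation), so no
    # separate grad_weight tensor is allocated. Shown here in plain PyTorch.
    main_grad.add_(grad_output_2d.float().t().matmul(input_2d.float()))

# Shapes: input [tokens, in_dim], grad_output [tokens, out_dim],
# main_grad [out_dim, in_dim], kept in fp32.
inp = torch.randn(8, 4, dtype=torch.bfloat16)
gout = torch.randn(8, 3, dtype=torch.bfloat16)
main_grad = torch.zeros(3, 4, dtype=torch.float32)
wgrad_fused_reference(inp, gout, main_grad)
```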
@@ -172,6 +172,14 @@ def parse_args(extra_args_provider=None, defaults={},

    if args.accumulate_allreduce_grads_in_fp32:
        assert args.DDP_impl == 'local'
        assert args.use_contiguous_buffers_in_local_ddp
    else:
        if args.gradient_accumulation_fusion:
            args.gradient_accumulation_fusion = False
            if args.rank == 0:
                print('Gradient accumulation fusion to linear layer weight '
                      'gradient computation is supported only with fp32 '
                      'gradient accumulation. Setting gradient_accumulation_fusion '
                      'to False', flush=True)

    # For torch DDP, we do not use contiguous buffer
    if args.DDP_impl == 'torch':
@@ -521,15 +529,21 @@ def _add_training_args(parser):

                       choices=['single', 'cyclic'],
                       help='Single pass vs multiple pass data loader')
    group.add_argument('--no-async-tensor-model-parallel-allreduce',
                       action='store_false',
                       help='Disable asynchronous execution of '
                       'tensor-model-parallel all-reduce with weight '
                       'gradient computation of a column-linear layer.',
                       dest='async_tensor_model_parallel_allreduce')
    group.add_argument('--no-persist-layer-norm', action='store_true',
                       help='Disable using persistent fused layer norm kernel. '
                       'This kernel supports only a set of hidden sizes. Please '
                       'check persist_ln_hidden_sizes if your hidden '
                       'size is supported.')
    group.add_argument('--no-gradient-accumulation-fusion',
                       action='store_false',
                       help='Disable fusing gradient accumulation to weight '
                       'gradient computation of linear layers',
                       dest='gradient_accumulation_fusion')
    return parser
......
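Both new flags follow the `--no-<feature>` pattern: `action='store_false'` with an explicit `dest` makes the feature default to True and lets the flag switch it off. A standalone illustration of just that argparse pattern (not Megatron's parser itself):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--no-gradient-accumulation-fusion',
                    action='store_false',
                    dest='gradient_accumulation_fusion',
                    help='Disable fusing gradient accumulation to weight '
                         'gradient computation of linear layers')

# The option defaults to True and the flag turns it off.
print(parser.parse_args([]).gradient_accumulation_fusion)          # True
print(parser.parse_args(['--no-gradient-accumulation-fusion'])
      .gradient_accumulation_fusion)                               # False
```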
@@ -94,6 +94,16 @@ def load(args):

    fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
        "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags)

    # =================================
    # Fused gradient accumulation to weight gradient computation of linear layer
    # =================================
    if args.gradient_accumulation_fusion:
        sources=[srcpath / 'fused_weight_gradient_dense.cpp',
                 srcpath / 'fused_weight_gradient_dense.cu']
        fused_dense_cuda = _cpp_extention_load_helper(
            "fused_dense_cuda", sources, [])


def _get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
......
#include <torch/torch.h>
#include <torch/extension.h>
#include <vector>
#include <stdio.h>

#include "type_shim.h"

template <typename T>
int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);

void wgrad_gemm_accum_fp32(const at::Tensor input, const at::Tensor d_output, at::Tensor d_weight) {
    at::Tensor input_2d, d_output_2d;
    // input tensor: collapse to the first dim
    auto in_sizes = input.sizes();
    if (input.dim() > 2) {
        input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]});
    } else {
        input_2d = input;
    }
    // d_output tensor: collapse to the first dim
    auto d_out_sizes = d_output.sizes();
    if (d_output.dim() > 2) {
        d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]});
    } else {
        d_output_2d = d_output;
    }

    int hidden_dim = input_2d.size(0);
    int in_dim = input_2d.size(1);
    int out_dim = d_weight.size(0);

    DISPATCH_HALF_BFLOAT_AND_FLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp32",
        int result = wgrad_gemm_accum_fp32_cuda<scalar_t>(
            input_2d.data_ptr<scalar_t>(),
            d_output_2d.data_ptr<scalar_t>(),
            d_weight.data_ptr<float>(),
            in_dim,
            hidden_dim,
            out_dim);
    );
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32, "wgrad gemm accum in fp32");
}
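A hedged usage sketch of the binding above, assuming the `fused_dense_cuda` extension has been JIT-built by `megatron.fused_kernels.load` and a CUDA device is available; 3D activations and gradients are collapsed to 2D inside the op and the result is accumulated in place into the fp32 weight-gradient buffer:

```python
import torch
import fused_dense_cuda  # available only after the JIT build shown above

seq, batch, in_dim, out_dim = 16, 4, 128, 256
inp = torch.randn(seq, batch, in_dim, dtype=torch.bfloat16, device='cuda')
d_out = torch.randn(seq, batch, out_dim, dtype=torch.bfloat16, device='cuda')
main_grad = torch.zeros(out_dim, in_dim, dtype=torch.float32, device='cuda')

# In place: main_grad += d_out^T @ inp (both collapsed to 2D), accumulated in fp32.
fused_dense_cuda.wgrad_gemm_accum_fp32(inp, d_out, main_grad)
```

The CUDA implementation that backs this binding (fused_weight_gradient_dense.cu in the sources listed above) follows.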
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <torch/torch.h>
/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>
// BF16 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float* alpha,
at::BFloat16* A,
int lda,
at::BFloat16* B,
int ldb,
const float* beta,
float* C,
int ldc) {
return cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_16BF,
lda,
B,
CUDA_R_16BF,
ldb,
beta,
C,
CUDA_R_32F,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
// FP16 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float* alpha,
at::Half* A,
int lda,
at::Half* B,
int ldb,
const float* beta,
float* C,
int ldc) {
return cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_16F,
lda,
B,
CUDA_R_16F,
ldb,
beta,
C,
CUDA_R_32F,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
// FP32 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float* alpha,
float* A,
int lda,
float* B,
int ldb,
const float* beta,
float* C,
int ldc) {
return cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_32F,
lda,
B,
CUDA_R_32F,
ldb,
beta,
C,
CUDA_R_32F,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
template <typename T>
int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim) {
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream;
cublasGetStream(handle, &stream);
const float alpha = 1.0;
const float beta = 1.0;
int status = 1;
status = gemmex_wrapper(
handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
in_dim,
out_dim,
hidden_dim,
&alpha,
input,
in_dim,
d_output,
out_dim,
&beta,
d_weight,
in_dim);
return status;
}
template int wgrad_gemm_accum_fp32_cuda<at::Half>(at::Half *input, at::Half *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
template int wgrad_gemm_accum_fp32_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
template int wgrad_gemm_accum_fp32_cuda<float>(float *input, float *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
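To unpack the GEMM: torch tensors are row-major while cuBLAS is column-major, so the row-major `input` of shape [tokens, in_dim] appears to cuBLAS as an in_dim x tokens matrix. With `CUBLAS_OP_N`/`CUBLAS_OP_T`, m = in_dim, n = out_dim, k = hidden_dim (the collapsed token count) and alpha = beta = 1, the call accumulates d_output^T @ input into the fp32 d_weight buffer. A hedged sanity check against plain PyTorch (assumes the extension is built and CUDA is available):

```python
import torch
import fused_dense_cuda  # the JIT-built extension above (assumption)

tokens, in_dim, out_dim = 64, 8, 16
inp = torch.randn(tokens, in_dim, dtype=torch.half, device='cuda')
gout = torch.randn(tokens, out_dim, dtype=torch.half, device='cuda')

wgrad_fused = torch.zeros(out_dim, in_dim, dtype=torch.float32, device='cuda')
fused_dense_cuda.wgrad_gemm_accum_fp32(inp, gout, wgrad_fused)

# Reference: the unfused path used when gradient_accumulation_fusion is off.
wgrad_ref = gout.float().t().matmul(inp.float())
torch.testing.assert_close(wgrad_fused, wgrad_ref, rtol=1e-3, atol=1e-3)
```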
@@ -39,6 +39,32 @@

  }
#define DISPATCH_HALF_BFLOAT_AND_FLOAT(TYPE, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Half: \
{ \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
  switch(TYPEIN) \
......
@@ -164,13 +164,13 @@ class DistributedDataParallel(DistributedDataParallelBase):

                grad_acc.register_hook(self._make_param_hook(param))
                self.grad_accs.append(grad_acc)

    def _make_param_hook(self, param):
        """Create the all-reduce hook for backprop."""
        # Hook used for back-prop.
        def param_hook(*unused):
            # Add the gradient to the buffer.
            if param.grad is not None:
                # The gradient function of linear layers is fused with GEMMs
                param.main_grad.add_(param.grad.data)
                # Now we can deallocate grad memory.
                param.grad = None
......
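A standalone sketch of the hook pattern above (outside Megatron, single parameter): gradients are folded into a persistent fp32 `main_grad` buffer and `param.grad` is freed. With gradient accumulation fusion enabled, linear-layer weight gradients are written into `main_grad` by the fused kernel itself, so their `param.grad` stays `None` and the `add_` is simply skipped for them:

```python
import torch

param = torch.nn.Parameter(torch.randn(4, 4))
param.main_grad = torch.zeros_like(param, dtype=torch.float32)

def make_param_hook(p):
    def param_hook(*unused):
        if p.grad is not None:
            p.main_grad.add_(p.grad.data)
            p.grad = None
    return param_hook

# Hook the AccumulateGrad node (the expand_as trick used by Megatron's local
# DDP), so the hook runs after the gradient has landed in param.grad.
grad_acc = param.expand_as(param).grad_fn.next_functions[0][0]
grad_acc.register_hook(make_param_hook(param))

(param * 2).sum().backward()
print(param.grad, param.main_grad.sum())   # None, tensor(32.)
```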
@@ -29,13 +29,19 @@ from megatron.model.utils import init_method_normal, scaled_init_method_normal

def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                       bias=None):
    """LM logits using word embedding weights."""
    args = get_args()
    # Parallel logits.
    if args.async_tensor_model_parallel_allreduce:
        input_parallel = input_
        async_grad_allreduce = mpu.get_tensor_model_parallel_world_size() > 1
    else:
        input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
        async_grad_allreduce = False
    # Matrix multiply.
    logits_parallel = mpu.LinearWithGradAccumulationAndAsyncAllreduce.apply(
        input_parallel, word_embeddings_weight, bias,
        args.gradient_accumulation_fusion,
        async_grad_allreduce)
    # Gather if needed.
    if parallel_output:
        return logits_parallel
......
@@ -49,6 +49,7 @@ from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pi

from .initialize import initialize_model_parallel
from .initialize import model_parallel_is_initialized
from .layers import LinearWithGradAccumulationAndAsyncAllreduce
from .layers import ColumnParallelLinear
from .layers import RowParallelLinear
from .layers import VocabParallelEmbedding
......
@@ -199,15 +199,18 @@ class VocabParallelEmbedding(torch.nn.Module):

        return output


class LinearWithGradAccumulationAndAsyncAllreduce(torch.autograd.Function):
    """
    Linear layer execution with asynchronous all-reduce and gradient accumulation
    fusion in backprop.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
                async_grad_allreduce):
        ctx.save_for_backward(input, weight)
        ctx.use_bias = bias is not None
        ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
        ctx.async_grad_allreduce = async_grad_allreduce
        output = torch.matmul(input, weight.t())
        if bias is not None:
            output = output + bias
@@ -215,19 +218,32 @@ class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):

    @staticmethod
    def backward(ctx, grad_output):
        import fused_dense_cuda
        input, weight = ctx.saved_tensors
        use_bias = ctx.use_bias
        grad_input = grad_output.matmul(weight)

        # Convert the tensor shapes to 2D for execution compatibility
        grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1],
                                       grad_output.shape[2])
        input = input.view(input.shape[0] * input.shape[1], input.shape[2])

        if ctx.async_grad_allreduce:
            # Asynchronous all-reduce
            handle = torch.distributed.all_reduce(
                grad_input, group=get_tensor_model_parallel_group(), async_op=True)
            # Delay the start of weight gradient computation shortly (3us) to have
            # all-reduce scheduled first and have GPU resources allocated
            _ = torch.empty(1, device=grad_output.device) + 1

        if ctx.gradient_accumulation_fusion:
            fused_dense_cuda.wgrad_gemm_accum_fp32(input, grad_output, weight.main_grad)
            grad_weight = None
        else:
            grad_weight = grad_output.t().matmul(input)
        grad_bias = grad_output.sum(dim=0) if use_bias else None
        if ctx.async_grad_allreduce:
            handle.wait()
        return grad_input, grad_weight, grad_bias, None, None


class ColumnParallelLinear(torch.nn.Module):
@@ -240,7 +256,7 @@ class ColumnParallelLinear(torch.nn.Module):

        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        init_method: method to initialize weights. Note that bias is always set
@@ -305,29 +321,23 @@ class ColumnParallelLinear(torch.nn.Module):

        else:
            self.register_parameter('bias', None)

        self.async_tensor_model_parallel_allreduce = (
            args.async_tensor_model_parallel_allreduce and
            world_size > 1)
        self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

    def forward(self, input_):
        bias = self.bias if not self.skip_bias_add else None

        if self.async_tensor_model_parallel_allreduce:
            input_parallel = input_
        else:
            # Set up backprop all-reduce.
            input_parallel = copy_to_tensor_model_parallel_region(input_)
        # Matrix multiply.
        output_parallel = LinearWithGradAccumulationAndAsyncAllreduce.apply(
            input_parallel, self.weight, bias, self.gradient_accumulation_fusion,
            self.async_tensor_model_parallel_allreduce)

        if self.gather_output:
            # All-gather across the partitions.
            output = gather_from_tensor_model_parallel_region(output_parallel)
@@ -415,7 +425,7 @@ class RowParallelLinear(torch.nn.Module):

            self.bias.zero_()
        else:
            self.register_parameter('bias', None)
        self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

    def forward(self, input_):

@@ -425,7 +435,9 @@

        else:
            input_parallel = scatter_to_tensor_model_parallel_region(input_)
        # Matrix multiply.
        output_parallel = LinearWithGradAccumulationAndAsyncAllreduce.apply(
            input_parallel, self.weight, None,
            self.gradient_accumulation_fusion, None)
        # All-reduce across all the partitions.
        output_ = reduce_from_tensor_model_parallel_region(output_parallel)
        if not self.skip_bias_add:
......
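One detail worth calling out in the autograd function above: `backward` must return one value per `forward` argument, so the two non-tensor flags receive `None` gradients. A single-process sketch of the same calling convention (no all-reduce, no fused kernel; purely illustrative):

```python
import torch

class LinearSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
                async_grad_allreduce):
        ctx.save_for_backward(input, weight)
        ctx.use_bias = bias is not None
        output = torch.matmul(input, weight.t())
        return output + bias if bias is not None else output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_output.matmul(weight)
        # Collapse [seq, batch, dim] to 2D, as the real backward does.
        grad_output_2d = grad_output.reshape(-1, grad_output.shape[-1])
        input_2d = input.reshape(-1, input.shape[-1])
        grad_weight = grad_output_2d.t().matmul(input_2d)
        grad_bias = grad_output_2d.sum(dim=0) if ctx.use_bias else None
        # One return value per forward() argument; None for the two flags.
        return grad_input, grad_weight, grad_bias, None, None

x = torch.randn(2, 3, 8, requires_grad=True)
w = torch.randn(4, 8, requires_grad=True)
y = LinearSketch.apply(x, w, None, False, False)
y.sum().backward()
```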