Commit 6fdbf26b authored by Vijay Korthikanti

Merge branch 'main' into allreduce-split

parents 0d77c0e9 53f3efc4
......@@ -172,6 +172,14 @@ def parse_args(extra_args_provider=None, defaults={},
if args.accumulate_allreduce_grads_in_fp32:
assert args.DDP_impl == 'local'
assert args.use_contiguous_buffers_in_local_ddp
else:
if args.gradient_accumulation_fusion:
args.gradient_accumulation_fusion = False
if args.rank == 0:
print('Gradient accumulation fusion to linear layer weight '
'gradient computation is supported only with fp32 '
'gradient accumulation. Setting gradient_accumulation_fusion '
'to False', flush=True)
# For torch DDP, we do not use contiguous buffer
if args.DDP_impl == 'torch':
......@@ -357,7 +365,8 @@ def _add_network_size_args(parser):
group.add_argument('--bert-no-binary-head', action='store_false',
help='Disable BERT binary head.',
dest='bert_binary_head')
group.add_argument('--num-experts', type=int, default=None,
help='Number of Experts in Switch Transformer (None means no Switch)')
return parser
......@@ -521,10 +530,11 @@ def _add_training_args(parser):
choices=['single', 'cyclic'],
help='Single pass vs multiple pass data loader')
group.add_argument('--no-async-tensor-model-parallel-allreduce',
action='store_true',
action='store_false',
help='Disable asynchronous execution of '
'tensor-model-parallel all-reduce with weight '
'gradient computation of a column-linear layer.')
'gradient computation of a column-linear layer.',
dest='async_tensor_model_parallel_allreduce')
group.add_argument('--no-persist-layer-norm', action='store_true',
help='Disable using persistent fused layer norm kernel. '
'This kernel supports only a set of hidden sizes. Please '
......@@ -532,8 +542,11 @@ def _add_training_args(parser):
'size is supported.')
group.add_argument('--model-parallel-memory-opt', action='store_true',
help='Enable model parallel memory optimization.')
group.add_argument('--no-gradient-accumulation-fusion',
action='store_false',
help='Disable fusing gradient accumulation to weight '
'gradient computation of linear layers',
dest='gradient_accumulation_fusion')
return parser
......
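Both --no-* options above follow the same argparse pattern: action='store_false' with a positive dest, so the feature defaults to on and the flag switches it off, letting downstream code test args.async_tensor_model_parallel_allreduce or args.gradient_accumulation_fusion directly instead of negating a no_* attribute (see the ColumnParallelLinear change below). A minimal standalone sketch of the pattern, using one of the flags from this diff:

import argparse

parser = argparse.ArgumentParser()
# Defaults to True; passing the flag stores False into the positively named dest.
parser.add_argument('--no-gradient-accumulation-fusion',
                    action='store_false',
                    dest='gradient_accumulation_fusion',
                    help='Disable fusing gradient accumulation into the weight '
                         'gradient computation of linear layers')

assert parser.parse_args([]).gradient_accumulation_fusion is True
assert parser.parse_args(
    ['--no-gradient-accumulation-fusion']).gradient_accumulation_fusion is False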
......@@ -94,6 +94,16 @@ def load(args):
fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
"fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags)
# =================================
# Fused gradient accumulation to weight gradient computation of linear layer
# =================================
if args.gradient_accumulation_fusion:
sources=[srcpath / 'fused_weight_gradient_dense.cpp',
srcpath / 'fused_weight_gradient_dense.cu']
fused_dense_cuda = _cpp_extention_load_helper(
"fused_dense_cuda", sources, [])
def _get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
......
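The body of _cpp_extention_load_helper is not part of this diff; a plausible sketch of the JIT build it performs for the new sources, assuming it wraps torch.utils.cpp_extension.load (the helper's real flags may differ):

import pathlib
from torch.utils import cpp_extension

def load_fused_dense_cuda(srcpath: pathlib.Path):
    # Hypothetical stand-in for _cpp_extention_load_helper("fused_dense_cuda", ...):
    # JIT-compile the binding (.cpp) and kernel (.cu) into an importable module.
    return cpp_extension.load(
        name='fused_dense_cuda',
        sources=[str(srcpath / 'fused_weight_gradient_dense.cpp'),
                 str(srcpath / 'fused_weight_gradient_dense.cu')],
        extra_cuda_cflags=[],  # the call above passes an empty extra_cuda_flags list
        verbose=True)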
#include <torch/torch.h>
#include <torch/extension.h>
#include <vector>
#include <stdio.h>
#include "type_shim.h"
template <typename T>
int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
void wgrad_gemm_accum_fp32(const at::Tensor input, const at::Tensor d_output, at::Tensor d_weight) {
at::Tensor input_2d, d_output_2d;
// input tensor: collapse all leading dimensions into the first dim (keep the last dim)
auto in_sizes = input.sizes();
if (input.dim() > 2) {
input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]});
} else {
input_2d = input;
}
// d_output tensor: collapse all leading dimensions into the first dim (keep the last dim)
auto d_out_sizes = d_output.sizes();
if (d_output.dim() > 2) {
d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]});
} else {
d_output_2d = d_output;
}
int hidden_dim = input_2d.size(0);
int in_dim = input_2d.size(1);
int out_dim = d_weight.size(0);
DISPATCH_HALF_BFLOAT_AND_FLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp32",
int result = wgrad_gemm_accum_fp32_cuda<scalar_t>(
input_2d.data_ptr<scalar_t>(),
d_output_2d.data_ptr<scalar_t>(),
d_weight.data_ptr<float>(),
in_dim,
hidden_dim,
out_dim);
);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32, "wgrad gemm accum in fp32");
}
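Once built, the module is imported inside the autograd backward (see megatron/mpu/layers.py further down) and called with the activations, the output gradient, and the fp32 main_grad buffer. A hedged usage sketch with illustrative shapes:

import torch
import fused_dense_cuda  # module name from PYBIND11_MODULE / TORCH_EXTENSION_NAME above

tokens, in_dim, out_dim = 8, 1024, 4096
input_2d = torch.randn(tokens, in_dim, dtype=torch.half, device='cuda')
d_output = torch.randn(tokens, out_dim, dtype=torch.half, device='cuda')
main_grad = torch.zeros(out_dim, in_dim, dtype=torch.float, device='cuda')

# Accumulates the fp32 weight gradient into main_grad (beta == 1 in the kernel below);
# 3-D inputs are also accepted and collapsed to 2-D by the binding.
fused_dense_cuda.wgrad_gemm_accum_fp32(input_2d, d_output, main_grad)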
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <torch/torch.h>
/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>
// BF16 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float* alpha,
at::BFloat16* A,
int lda,
at::BFloat16* B,
int ldb,
const float* beta,
float* C,
int ldc) {
return cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_16BF,
lda,
B,
CUDA_R_16BF,
ldb,
beta,
C,
CUDA_R_32F,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
// FP16 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float* alpha,
at::Half* A,
int lda,
at::Half* B,
int ldb,
const float* beta,
float* C,
int ldc) {
return cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_16F,
lda,
B,
CUDA_R_16F,
ldb,
beta,
C,
CUDA_R_32F,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
// FP32 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float* alpha,
float* A,
int lda,
float* B,
int ldb,
const float* beta,
float* C,
int ldc) {
return cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_32F,
lda,
B,
CUDA_R_32F,
ldb,
beta,
C,
CUDA_R_32F,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
template <typename T>
int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim) {
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream;
cublasGetStream(handle, &stream);
const float alpha = 1.0;
const float beta = 1.0;
int status = 1;
status = gemmex_wrapper(
handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
in_dim,
out_dim,
hidden_dim,
&alpha,
input,
in_dim,
d_output,
out_dim,
&beta,
d_weight,
in_dim);
return status;
}
template int wgrad_gemm_accum_fp32_cuda<at::Half>(at::Half *input, at::Half *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
template int wgrad_gemm_accum_fp32_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
template int wgrad_gemm_accum_fp32_cuda<float>(float *input, float *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
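For reference, the cuBLAS call (OP_N, OP_T with m = in_dim, n = out_dim, k = hidden_dim, alpha = beta = 1) accumulates d_output^T @ input into the fp32 weight-gradient buffer, i.e. the same quantity the unfused path computes with grad_output.t().matmul(total_input). A plain PyTorch equivalent, useful for checking the kernel:

import torch

def wgrad_accum_fp32_reference(input_2d: torch.Tensor,
                               d_output_2d: torch.Tensor,
                               d_weight: torch.Tensor) -> None:
    # d_weight: fp32 buffer of shape [out_dim, in_dim]; inputs may be fp16/bf16/fp32.
    # Upcast before the matmul so the accumulation happens in fp32, as in the kernel.
    d_weight.add_(d_output_2d.float().t().matmul(input_2d.float()))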
......@@ -39,6 +39,32 @@
}
#define DISPATCH_HALF_BFLOAT_AND_FLOAT(TYPE, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Half: \
{ \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch(TYPEIN) \
......
......@@ -164,13 +164,13 @@ class DistributedDataParallel(DistributedDataParallelBase):
grad_acc.register_hook(self._make_param_hook(param))
self.grad_accs.append(grad_acc)
def _make_param_hook(self, param):
"""Create the all-reduce hook for backprop."""
# Hook used for back-prop.
def param_hook(*unused):
# Add the gradient to the buffer.
if param.grad.data is not None:
if param.grad is not None:
# The gradient function of linear layers is fused with GEMMs
param.main_grad.add_(param.grad.data)
# Now we can deallocate grad memory.
param.grad = None
......
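The relaxed check above reflects gradient accumulation fusion: when the fused kernel has already written a parameter's weight gradient into param.main_grad, autograd leaves param.grad as None by the time the accumulator hook fires. A minimal sketch of the hook mechanism, with the main_grad buffer allocated per parameter instead of from the contiguous buffer the real local DDP uses:

import torch

def make_param_hook(param: torch.nn.Parameter):
    def param_hook(*unused):
        # Skipped for params whose gradient was already accumulated into
        # main_grad by the fused weight-gradient GEMM (param.grad stays None).
        if param.grad is not None:
            param.main_grad.add_(param.grad.data)
            # Deallocate .grad; everything downstream reads main_grad.
            param.grad = None
    return param_hook

param = torch.nn.Parameter(torch.randn(4, 4))
param.main_grad = torch.zeros_like(param.data)
# Same registration trick as in distributed.py: hook the AccumulateGrad node and
# keep a reference to it (the real code stores them in self.grad_accs).
grad_acc = param.expand_as(param).grad_fn.next_functions[0][0]
grad_acc.register_hook(make_param_hook(param))

(param * 2.0).sum().backward()
print(param.main_grad)  # gradients now live in the separate buffer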
......@@ -31,7 +31,6 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
bias=None):
"""LM logits using word embedding weights."""
args = get_args()
# Parallel logits.
if args.async_tensor_model_parallel_allreduce or\
args.model_parallel_memory_opt:
......
......@@ -116,6 +116,53 @@ class ParallelMLP(MegatronModule):
output, output_bias = self.dense_4h_to_h(intermediate_parallel)
return output, output_bias
class SwitchMLP(MegatronModule):
"""
Routes input to one of N MLP "experts"
"""
def __init__(self, init_method, output_layer_init_method):
super(SwitchMLP, self).__init__()
args = get_args()
self.router = torch.nn.Linear(args.hidden_size, args.num_experts)
self.experts = torch.nn.ModuleList()
for i in range(args.num_experts):
self.experts.append(ParallelMLP(init_method, output_layer_init_method))
def forward(self, hidden_states):
# hidden_states: [b, s, h]
b = hidden_states.size(0)
s = hidden_states.size(1)
h = hidden_states.size(2)
route = self.router(hidden_states)
route = torch.nn.functional.softmax(route, dim=2)
max_prob, max_ind = torch.max(route, dim=2)
max_prob = torch.unsqueeze(max_prob, 2) # [b s 1]
# TODO (rprenger): this could be made easier to read
# Converting [b, s, h] to [b*s, h].
# Each vector could be routed differently
hidden_states = hidden_states.view(-1, hidden_states.size(2)) # [b*s h]
max_prob = max_prob.view(-1, max_prob.size(2)) # [b*s 1]
max_ind = max_ind.view(-1) # [b*s]
output_total = torch.empty_like(hidden_states)
output_bias_total = torch.empty_like(hidden_states)
#TODO (rprenger) This does each expert in serial, but it could be parallelized
for expert_num, expert in enumerate(self.experts):
local_indices = (max_ind == expert_num).nonzero()
hidden = hidden_states[local_indices,:]
output, output_bias = expert(hidden)
output_bias = output_bias.expand_as(output)
output_total[local_indices,:] = output
output_bias_total[local_indices,:] = output_bias
output_total = output_total*max_prob
output_bias_total = output_bias_total*max_prob
output_total = output_total.view(b, s, h)
output_bias_total = output_bias_total.view(b, s, h)
return output_total, output_bias_total
class ParallelAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
......@@ -482,8 +529,10 @@ class ParallelTransformerLayer(MegatronModule):
sequence_parallel=args.model_parallel_memory_opt)
# MLP
self.mlp = ParallelMLP(init_method,
output_layer_init_method)
if args.num_experts is not None:
self.mlp = SwitchMLP(init_method, output_layer_init_method)
else:
self.mlp = ParallelMLP(init_method, output_layer_init_method)
def forward(self, hidden_states, attention_mask,
encoder_output=None, enc_dec_attn_mask=None,
......
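SwitchMLP above implements top-1 (Switch Transformer) routing: softmax the router logits, send each token to its argmax expert, run the experts serially on their token subsets, and scale the outputs by the routing probability. A self-contained sketch of the same routing logic, with plain nn.Linear layers standing in for the ParallelMLP experts:

import torch

class TinySwitchMLP(torch.nn.Module):
    """Top-1 routing over simple Linear experts (ParallelMLP stand-ins)."""
    def __init__(self, hidden_size: int, num_experts: int):
        super().__init__()
        self.router = torch.nn.Linear(hidden_size, num_experts)
        self.experts = torch.nn.ModuleList(
            torch.nn.Linear(hidden_size, hidden_size) for _ in range(num_experts))

    def forward(self, hidden_states):                          # [b, s, h]
        b, s, h = hidden_states.shape
        route = torch.nn.functional.softmax(self.router(hidden_states), dim=2)
        max_prob, max_ind = torch.max(route, dim=2)             # [b, s]
        flat = hidden_states.view(-1, h)                        # [b*s, h]
        max_prob = max_prob.view(-1, 1)                         # [b*s, 1]
        max_ind = max_ind.view(-1)                              # [b*s]
        output = torch.empty_like(flat)
        # Experts run serially over their routed tokens, as in SwitchMLP.forward.
        for expert_num, expert in enumerate(self.experts):
            idx = (max_ind == expert_num).nonzero(as_tuple=True)[0]
            if idx.numel() > 0:
                output[idx] = expert(flat[idx])
        return (output * max_prob).view(b, s, h)

out = TinySwitchMLP(hidden_size=8, num_experts=4)(torch.randn(2, 3, 8))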
......@@ -49,6 +49,7 @@ from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pi
from .initialize import initialize_model_parallel
from .initialize import model_parallel_is_initialized
from .layers import LinearWithGradAccumulationAndAsyncAllreduce
from .layers import ColumnParallelLinear
from .layers import RowParallelLinear
from .layers import VocabParallelEmbedding
......
......@@ -237,6 +237,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
import fused_dense_cuda
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias
......@@ -280,12 +281,10 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
sub_grad_input = torch.empty(dim_size, dtype=input.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
# reduce_scatter
handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input,
group=get_tensor_model_parallel_group(),
async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# reduce scatter scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
......@@ -298,14 +297,15 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_allreduce:
handle.wait()
return grad_input, grad_weight, grad_bias
if ctx.model_parallel_memory_opt:
handle.wait()
return sub_grad_input, grad_weight, grad_bias
if ctx.async_grad_allreduce:
handle.wait()
return grad_input, grad_weight, grad_bias
class ColumnParallelLinear(torch.nn.Module):
"""Linear layer with column parallelism.
......@@ -317,7 +317,7 @@ class ColumnParallelLinear(torch.nn.Module):
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
bias: If true, add bias
gather_output: If true, call all-gather on output and make Y avaiable
gather_output: If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is Y_i = XA_i
init_method: method to initialize weights. Note that bias is always set
......@@ -382,13 +382,14 @@ class ColumnParallelLinear(torch.nn.Module):
else:
self.register_parameter('bias', None)
self.async_tensor_model_parallel_allreduce = (
not args.no_async_tensor_model_parallel_allreduce and
args.async_tensor_model_parallel_allreduce and
world_size > 1)
self.model_parallel_memory_opt = (
args.model_parallel_memory_opt and
world_size > 1)
assert not self.async_tensor_model_parallel_allreduce or \
not self.model_parallel_memory_opt
self.gradient_accumulation_fusion = args.gradient_accumulation_fusion
def forward(self, input_):
......@@ -491,8 +492,8 @@ class RowParallelLinear(torch.nn.Module):
self.bias.zero_()
else:
self.register_parameter('bias', None)
self.model_parallel_memory_opt = args.model_parallel_memory_opt
self.gradient_accumulation_fusion = args.gradient_accumulation_fusion
def forward(self, input_):
......@@ -503,7 +504,9 @@ class RowParallelLinear(torch.nn.Module):
assert not self.model_parallel_memory_opt
input_parallel = scatter_to_tensor_model_parallel_region(input_)
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight)
output_parallel = LinearWithGradAccumulationAndAsyncAllreduce.apply(
input_parallel, self.weight, None,
self.gradient_accumulation_fusion, None)
# All-reduce across all the partitions.
if self.model_parallel_memory_opt:
output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
......
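With this change RowParallelLinear feeds its GEMM through LinearWithGradAccumulationAndAsyncAllreduce rather than F.linear, so the weight-gradient GEMM can accumulate straight into weight.main_grad via the fused kernel. A stripped-down sketch of that idea (illustrative class name; no async all-reduce or sequence-parallel branches), assuming weight.main_grad is the fp32 buffer that the local DDP wrapper attaches:

import torch

class LinearWithFusedWgradAccum(torch.autograd.Function):
    """Forward is a plain linear; backward can fuse d_weight accumulation."""

    @staticmethod
    def forward(ctx, input_, weight, gradient_accumulation_fusion):
        ctx.save_for_backward(input_, weight)
        ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
        return torch.matmul(input_, weight.t())

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight = ctx.saved_tensors
        grad_input = grad_output.matmul(weight)
        grad_output_2d = grad_output.reshape(-1, grad_output.shape[-1])
        input_2d = input_.reshape(-1, input_.shape[-1])
        if ctx.gradient_accumulation_fusion:
            import fused_dense_cuda
            # Fused path: accumulate d_weight into the fp32 buffer in one GEMM
            # and return no weight grad (the DDP hook tolerates param.grad == None).
            fused_dense_cuda.wgrad_gemm_accum_fp32(input_2d, grad_output_2d,
                                                   weight.main_grad)
            grad_weight = None
        else:
            grad_weight = grad_output_2d.t().matmul(input_2d)
        return grad_input, grad_weight, None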