Commit 6fdbf26b authored by Vijay Korthikanti

Merge branch 'main' into allreduce-split

parents 0d77c0e9 53f3efc4
@@ -172,6 +172,14 @@ def parse_args(extra_args_provider=None, defaults={},
    if args.accumulate_allreduce_grads_in_fp32:
        assert args.DDP_impl == 'local'
        assert args.use_contiguous_buffers_in_local_ddp
    else:
        if args.gradient_accumulation_fusion:
            args.gradient_accumulation_fusion = False
            if args.rank == 0:
                print('Gradient accumulation fusion to linear layer weight '
                      'gradient computation is supported only with fp32 '
                      'gradient accumulation. Setting gradient_accumulation_fusion '
                      'to False', flush=True)

    # For torch DDP, we do not use contiguous buffer
    if args.DDP_impl == 'torch':
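The check above ties gradient accumulation fusion to fp32 gradient accumulation. A minimal restatement of that rule as a standalone helper, purely for illustration (this function is hypothetical and not part of the commit):

def resolve_gradient_accumulation_fusion(accumulate_allreduce_grads_in_fp32,
                                         requested_fusion):
    # Fusion writes weight gradients straight into an fp32 main_grad buffer,
    # so it is only kept when grads are accumulated/all-reduced in fp32.
    if accumulate_allreduce_grads_in_fp32:
        return requested_fusion
    return False

assert resolve_gradient_accumulation_fusion(True, True) is True
assert resolve_gradient_accumulation_fusion(False, True) is False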
@@ -357,7 +365,8 @@ def _add_network_size_args(parser):
    group.add_argument('--bert-no-binary-head', action='store_false',
                       help='Disable BERT binary head.',
                       dest='bert_binary_head')
    group.add_argument('--num-experts', type=int, default=None,
                       help='Number of Experts in Switch Transformer (None means no Switch)')

    return parser
@@ -521,10 +530,11 @@ def _add_training_args(parser):
                       choices=['single', 'cyclic'],
                       help='Single pass vs multiple pass data loader')
    group.add_argument('--no-async-tensor-model-parallel-allreduce',
                       action='store_false',
                       help='Disable asynchronous execution of '
                       'tensor-model-parallel all-reduce with weight '
                       'gradient computation of a column-linear layer.',
                       dest='async_tensor_model_parallel_allreduce')
    group.add_argument('--no-persist-layer-norm', action='store_true',
                       help='Disable using persistent fused layer norm kernel. '
                       'This kernel supports only a set of hidden sizes. Please '
@@ -532,8 +542,11 @@ def _add_training_args(parser):
                       'size is supported.')
    group.add_argument('--model-parallel-memory-opt', action='store_true',
                       help='Enable model parallel memory optimization.')
    group.add_argument('--no-gradient-accumulation-fusion',
                       action='store_false',
                       help='Disable fusing gradient accumulation to weight '
                       'gradient computation of linear layers',
                       dest='gradient_accumulation_fusion')

    return parser
...
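Both new switches, --no-async-tensor-model-parallel-allreduce and --no-gradient-accumulation-fusion, follow the same argparse pattern: action='store_false' with an explicit dest, so the feature defaults to enabled and the flag turns it off. A minimal standalone sketch of that pattern (a hypothetical parser, not the real Megatron argument plumbing):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--no-gradient-accumulation-fusion', action='store_false',
                    dest='gradient_accumulation_fusion',
                    help='Disable fusing gradient accumulation to weight '
                         'gradient computation of linear layers')

args = parser.parse_args([])                                     # default: enabled
assert args.gradient_accumulation_fusion is True
args = parser.parse_args(['--no-gradient-accumulation-fusion'])  # explicitly disabled
assert args.gradient_accumulation_fusion is False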
@@ -94,6 +94,16 @@ def load(args):
    fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
        "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags)

    # =================================
    # Fused gradient accumulation to weight gradient computation of linear layer
    # =================================
    if args.gradient_accumulation_fusion:
        sources = [srcpath / 'fused_weight_gradient_dense.cpp',
                   srcpath / 'fused_weight_gradient_dense.cu']
        fused_dense_cuda = _cpp_extention_load_helper(
            "fused_dense_cuda", sources, [])


def _get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
...
#include <torch/torch.h>
#include <torch/extension.h>
#include <vector>
#include <stdio.h>
#include "type_shim.h"
template <typename T>
int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
void wgrad_gemm_accum_fp32(const at::Tensor input, const at::Tensor d_output, at::Tensor d_weight) {
at::Tensor input_2d, d_output_2d;
// input tensor: collapse to the first dim
auto in_sizes = input.sizes();
if (input.dim() > 2) {
input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]});
} else {
input_2d = input;
}
// d_output tensor: collapse to the first dim
auto d_out_sizes = d_output.sizes();
if (d_output.dim() > 2) {
d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]});
} else {
d_output_2d = d_output;
}
int hidden_dim = input_2d.size(0);
int in_dim = input_2d.size(1);
int out_dim = d_weight.size(0);
DISPATCH_HALF_BFLOAT_AND_FLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp32",
int result = wgrad_gemm_accum_fp32_cuda<scalar_t>(
input_2d.data_ptr<scalar_t>(),
d_output_2d.data_ptr<scalar_t>(),
d_weight.data_ptr<float>(),
in_dim,
hidden_dim,
out_dim);
);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32, "wgrad gemm accum in fp32");
}
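The pybind11 module above exposes a single op, wgrad_gemm_accum_fp32(input, d_output, d_weight). A usage sketch from Python, assuming the extension has been built by the fused-kernels loader earlier in this diff and a CUDA device is available (illustrative only, not part of the commit):

import torch
import fused_dense_cuda  # name registered by _cpp_extention_load_helper above

tokens, in_dim, out_dim = 4096, 1024, 4096       # tokens = flattened [s*b]
total_input = torch.randn(tokens, in_dim, dtype=torch.half, device='cuda')
grad_output = torch.randn(tokens, out_dim, dtype=torch.half, device='cuda')
main_grad = torch.zeros(out_dim, in_dim, dtype=torch.float, device='cuda')

# Accumulates grad_output^T @ total_input into main_grad in fp32 (beta = 1.0),
# so calling it once per micro-batch keeps summing into the same buffer.
fused_dense_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, main_grad)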
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <torch/torch.h>
/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>
// BF16 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float* alpha,
at::BFloat16* A,
int lda,
at::BFloat16* B,
int ldb,
const float* beta,
float* C,
int ldc) {
return cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_16BF,
lda,
B,
CUDA_R_16BF,
ldb,
beta,
C,
CUDA_R_32F,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
// FP16 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float* alpha,
at::Half* A,
int lda,
at::Half* B,
int ldb,
const float* beta,
float* C,
int ldc) {
return cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_16F,
lda,
B,
CUDA_R_16F,
ldb,
beta,
C,
CUDA_R_32F,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
// FP32 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float* alpha,
float* A,
int lda,
float* B,
int ldb,
const float* beta,
float* C,
int ldc) {
return cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_32F,
lda,
B,
CUDA_R_32F,
ldb,
beta,
C,
CUDA_R_32F,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
template <typename T>
int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim) {
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream;
cublasGetStream(handle, &stream);
const float alpha = 1.0;
const float beta = 1.0;
int status = 1;
status = gemmex_wrapper(
handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
in_dim,
out_dim,
hidden_dim,
&alpha,
input,
in_dim,
d_output,
out_dim,
&beta,
d_weight,
in_dim);
return status;
}
template int wgrad_gemm_accum_fp32_cuda<at::Half>(at::Half *input, at::Half *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
template int wgrad_gemm_accum_fp32_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
template int wgrad_gemm_accum_fp32_cuda<float>(float *input, float *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
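The cuBLAS call above uses alpha = beta = 1.0 with an fp32 accumulator, so the kernel performs d_weight += d_output_2d^T @ input_2d whether the GEMM inputs are fp16, bf16, or fp32. A pure-PyTorch reference for that computation (a CPU sketch for clarity, not the kernel itself):

import torch

hidden_dim, in_dim, out_dim = 16, 8, 4           # hidden_dim = flattened tokens
input_2d = torch.randn(hidden_dim, in_dim, dtype=torch.half)
d_output_2d = torch.randn(hidden_dim, out_dim, dtype=torch.half)
d_weight = torch.zeros(out_dim, in_dim, dtype=torch.float)

# Same math as the fused kernel: accumulate the low-precision GEMM result
# into the fp32 weight-gradient buffer.
d_weight.add_(d_output_2d.float().t() @ input_2d.float())
print(d_weight.shape)                            # torch.Size([4, 8])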
@@ -39,6 +39,32 @@
}

#define DISPATCH_HALF_BFLOAT_AND_FLOAT(TYPE, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Half: \
{ \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
  switch(TYPEIN) \
...
@@ -164,13 +164,13 @@ class DistributedDataParallel(DistributedDataParallelBase):
                grad_acc.register_hook(self._make_param_hook(param))
                self.grad_accs.append(grad_acc)

    def _make_param_hook(self, param):
        """Create the all-reduce hook for backprop."""
        # Hook used for back-prop.
        def param_hook(*unused):
            # Add the gradient to the buffer.
            if param.grad is not None:
                # The gradient function of linear layers is fused with GEMMs
                param.main_grad.add_(param.grad.data)
                # Now we can deallocate grad memory.
                param.grad = None
...
@@ -31,7 +31,6 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                       bias=None):
    """LM logits using word embedding weights."""
    args = get_args()
    # Parallel logits.
    if args.async_tensor_model_parallel_allreduce or\
       args.model_parallel_memory_opt:
...
@@ -116,6 +116,53 @@ class ParallelMLP(MegatronModule):
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias

class SwitchMLP(MegatronModule):
    """
    Routes input to one of N MLP "experts"
    """
    def __init__(self, init_method, output_layer_init_method):
        super(SwitchMLP, self).__init__()
        args = get_args()
        self.router = torch.nn.Linear(args.hidden_size, args.num_experts)
        self.experts = torch.nn.ModuleList()
        for i in range(args.num_experts):
            self.experts.append(ParallelMLP(init_method, output_layer_init_method))

    def forward(self, hidden_states):
        # hidden_states: [b, s, h]
        b = hidden_states.size(0)
        s = hidden_states.size(1)
        h = hidden_states.size(2)
        route = self.router(hidden_states)
        route = torch.nn.functional.softmax(route, dim=2)
        max_prob, max_ind = torch.max(route, dim=2)
        max_prob = torch.unsqueeze(max_prob, 2)  # [b s 1]

        # TODO (rprenger) TODO this could be made easier to read
        # Converting [b, s, h] to [b*s, h].
        # Each vector could be routed differently
        hidden_states = hidden_states.view(-1, hidden_states.size(2))  # [b*s h]
        max_prob = max_prob.view(-1, max_prob.size(2))  # [b*s 1]
        max_ind = max_ind.view(-1)  # [b*s]

        output_total = torch.empty_like(hidden_states)
        output_bias_total = torch.empty_like(hidden_states)
        # TODO (rprenger) This does each expert in serial, but it could be parallelized
        for expert_num, expert in enumerate(self.experts):
            local_indices = (max_ind == expert_num).nonzero()
            hidden = hidden_states[local_indices, :]
            output, output_bias = expert(hidden)
            output_bias = output_bias.expand_as(output)
            output_total[local_indices, :] = output
            output_bias_total[local_indices, :] = output_bias

        output_total = output_total*max_prob
        output_bias_total = output_bias_total*max_prob
        output_total = output_total.view(b, s, h)
        output_bias_total = output_bias_total.view(b, s, h)

        return output_total, output_bias_total
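As a quick illustration of the top-1 ("switch") routing implemented above, the sketch below runs the same softmax/argmax routing on a toy tensor with a plain nn.Linear router and no model parallelism (assumed toy shapes; illustrative only, not Megatron code):

import torch

b, s, h, num_experts = 2, 4, 8, 3
hidden_states = torch.randn(b, s, h)
router = torch.nn.Linear(h, num_experts)

route = torch.nn.functional.softmax(router(hidden_states), dim=2)
max_prob, max_ind = torch.max(route, dim=2)      # [b, s] routing prob and expert id
tokens = hidden_states.view(-1, h)               # [b*s, h]
assignments = max_ind.view(-1)                   # expert id per token

for expert_num in range(num_experts):
    token_ids = (assignments == expert_num).nonzero(as_tuple=True)[0]
    print(f'expert {expert_num}: {token_ids.numel()} of {tokens.size(0)} tokens')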
class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

@@ -482,8 +529,10 @@ class ParallelTransformerLayer(MegatronModule):
            sequence_parallel=args.model_parallel_memory_opt)

        # MLP
        if args.num_experts is not None:
            self.mlp = SwitchMLP(init_method, output_layer_init_method)
        else:
            self.mlp = ParallelMLP(init_method, output_layer_init_method)

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
...
@@ -49,6 +49,7 @@ from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pi
from .initialize import initialize_model_parallel
from .initialize import model_parallel_is_initialized
from .layers import LinearWithGradAccumulationAndAsyncAllreduce
from .layers import ColumnParallelLinear
from .layers import RowParallelLinear
from .layers import VocabParallelEmbedding
...
@@ -237,6 +237,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
    @staticmethod
    def backward(ctx, grad_output):
        import fused_dense_cuda
        input, weight = ctx.saved_tensors
        use_bias = ctx.use_bias

@@ -280,12 +281,10 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
            sub_grad_input = torch.empty(dim_size, dtype=input.dtype,
                                         device=torch.cuda.current_device(),
                                         requires_grad=False)
            # reduce_scatter
            handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input,
                                                            group=get_tensor_model_parallel_group(),
                                                            async_op=True)
            # Delay the start of weight gradient computation shortly (3us) to have
            # reduce scatter scheduled first and have GPU resources allocated
            _ = torch.empty(1, device=grad_output.device) + 1
@@ -298,14 +297,15 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
        grad_weight = grad_output.t().matmul(total_input)
        grad_bias = grad_output.sum(dim=0) if use_bias else None

        if ctx.model_parallel_memory_opt:
            handle.wait()
            return sub_grad_input, grad_weight, grad_bias

        if ctx.async_grad_allreduce:
            handle.wait()

        return grad_input, grad_weight, grad_bias


class ColumnParallelLinear(torch.nn.Module):
    """Linear layer with column parallelism.
@@ -317,7 +317,7 @@ class ColumnParallelLinear(torch.nn.Module):
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        init_method: method to initialize weights. Note that bias is always set
@@ -382,13 +382,14 @@ class ColumnParallelLinear(torch.nn.Module):
        else:
            self.register_parameter('bias', None)

        self.async_tensor_model_parallel_allreduce = (
            args.async_tensor_model_parallel_allreduce and
            world_size > 1)
        self.model_parallel_memory_opt = (
            args.model_parallel_memory_opt and
            world_size > 1)
        assert not self.async_tensor_model_parallel_allreduce or \
            not self.model_parallel_memory_opt
        self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

    def forward(self, input_):
@@ -491,8 +492,8 @@ class RowParallelLinear(torch.nn.Module):
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)
        self.model_parallel_memory_opt = args.model_parallel_memory_opt
        self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

    def forward(self, input_):

@@ -503,7 +504,9 @@ class RowParallelLinear(torch.nn.Module):
            assert not self.model_parallel_memory_opt
            input_parallel = scatter_to_tensor_model_parallel_region(input_)
        # Matrix multiply.
        output_parallel = LinearWithGradAccumulationAndAsyncAllreduce.apply(
            input_parallel, self.weight, None,
            self.gradient_accumulation_fusion, None)
        # All-reduce across all the partitions.
        if self.model_parallel_memory_opt:
            output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
...
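The weight-gradient fusion itself happens inside the autograd function's backward(), in lines collapsed out of the hunks above. Given the fused_dense_cuda import added to backward() and the gradient_accumulation_fusion flag threaded through apply() here, that path presumably looks roughly like the sketch below; this is an assumption about the hidden lines, not the commit's actual code:

# Hypothetical sketch of the fused weight-gradient path inside backward():
if ctx.gradient_accumulation_fusion:
    # Accumulate directly into the fp32 main_grad buffer; no separate
    # grad_weight tensor is materialized for the DDP hook to add later.
    fused_dense_cuda.wgrad_gemm_accum_fp32(total_input, grad_output,
                                           weight.main_grad)
    grad_weight = None
else:
    grad_weight = grad_output.t().matmul(total_input)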