Commit 6e1bde1e authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'slym/allreduce_custom' into 'main'

allreduce overlap with wgrad gemm

See merge request ADLR/megatron-lm!316
parents b7ae685f 4df8b7a2
......@@ -462,6 +462,11 @@ def _add_training_args(parser):
group.add_argument('--dataloader-type', type=str, default=None,
choices=['single', 'cyclic'],
help='Single pass vs multiple pass data loader')
group.add_argument('--no-async-tensor-model-parallel-allreduce',
action='store_true',
help='Disable asynchronous execution of '
'tensor-model-parallel all-reduce with weight '
'gradient compuation of a column-linear layer.')
return parser
......
......@@ -176,11 +176,22 @@ def _initialize_distributed():
else:
args.local_rank = device
torch.cuda.set_device(device)
# Increase cuda stream priority of NCCL ops when overlapping with other ops
if (not args.no_async_tensor_model_parallel_allreduce and
args.tensor_model_parallel_size > 1):
from torch._C._distributed_c10d import ProcessGroupNCCL
pg_options = ProcessGroupNCCL.Options()
pg_options.is_high_priority_stream = True
pg_options._timeout = timedelta(days=7)
else:
pg_options = None
# Call the init process
torch.distributed.init_process_group(
backend=args.distributed_backend,
world_size=args.world_size, rank=args.rank,
timeout=timedelta(days=7))
timeout=timedelta(days=7),
pg_options=pg_options)
# Set the tensor model-parallel, pipeline model-parallel, and
# data-parallel communicators.
......
......@@ -27,6 +27,7 @@ from torch.nn.parameter import Parameter
from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_world_size
from .initialize import get_tensor_model_parallel_group
from .mappings import copy_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region
......@@ -198,6 +199,37 @@ class VocabParallelEmbedding(torch.nn.Module):
return output
class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
"""
Column-parallel linear layer execution with asynchronous all-reduce
execution in backprop.
"""
@staticmethod
def forward(ctx, input, weight, bias):
ctx.save_for_backward(input, weight)
ctx.use_bias = bias is not None
output = torch.matmul(input, weight.t())
if bias is not None:
output = output + bias
return output
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias
grad_input = grad_output.matmul(weight)
# Asyncronous all-reduce
handle = torch.distributed.all_reduce(
grad_input, group=get_tensor_model_parallel_group(), async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
grad_weight = grad_output.t().matmul(input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
handle.wait()
return grad_input, grad_weight, grad_bias
class ColumnParallelLinear(torch.nn.Module):
"""Linear layer with column parallelism.
......@@ -272,16 +304,30 @@ class ColumnParallelLinear(torch.nn.Module):
self.bias.zero_()
else:
self.register_parameter('bias', None)
self.async_tensor_model_parallel_allreduce = (
not args.no_async_tensor_model_parallel_allreduce and
world_size > 1)
def forward(self, input_):
# Set up backprop all-reduce.
input_parallel = copy_to_tensor_model_parallel_region(input_)
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output_parallel = F.linear(input_parallel, self.weight, bias)
if self.async_tensor_model_parallel_allreduce:
input_shape = input_.shape
input_ = input_.view(input_shape[0] * input_shape[1],input_shape[2])
# Maxtrix multiply with asynchronouse all-reduce execution
output_parallel = ColumnParallelLinearWithAsyncAllreduce.apply(
input_, self.weight, bias)
output_parallel = output_parallel.view(
input_shape[0], input_shape[1], output_parallel.shape[1])
else:
# Set up backprop all-reduce.
input_parallel = copy_to_tensor_model_parallel_region(input_)
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight, bias)
if self.gather_output:
# All-gather across the partitions.
output = gather_from_tensor_model_parallel_region(output_parallel)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment