Commit cf7efd4f authored by Sangkug Lym, committed by slym

# This is a combination of 2 commits.

allreduce overlap with wgrad gemm

change custom delay to dummy add
parent b7ae685f
......@@ -462,6 +462,10 @@ def _add_training_args(parser):
    group.add_argument('--dataloader-type', type=str, default=None,
                       choices=['single', 'cyclic'],
                       help='Single pass vs multiple pass data loader')
    group.add_argument('--async-tensor-parallel-allreduce', action='store_true',
                       help='Enable asynchronous execution of tensor-parallel allreduce '
                       'with other GPU operators',
                       dest='async_tensor_parallel_allreduce')
    return parser
......
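For context, a minimal sketch of how the new flag is consumed at startup; the launch command and companion flags are illustrative, not part of this diff:

    # Hypothetical launch enabling the overlap (flags illustrative):
    #   python pretrain_gpt.py --tensor-model-parallel-size 8 \
    #       --async-tensor-parallel-allreduce
    import argparse
    parser = argparse.ArgumentParser()
    group = parser.add_argument_group('training')
    group.add_argument('--async-tensor-parallel-allreduce', action='store_true',
                       dest='async_tensor_parallel_allreduce')
    args = parser.parse_args(['--async-tensor-parallel-allreduce'])
    assert args.async_tensor_parallel_allreduce  # gates the new code paths below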
......@@ -176,11 +176,22 @@ def _initialize_distributed():
        else:
            args.local_rank = device
        torch.cuda.set_device(device)
        # Increase the CUDA stream priority of NCCL ops when overlapping them
        # with other ops.
        if (args.async_tensor_parallel_allreduce and
                args.tensor_model_parallel_size > 1):
            from torch._C._distributed_c10d import ProcessGroupNCCL
            pg_options = ProcessGroupNCCL.Options()
            pg_options.is_high_priority_stream = True
            pg_options._timeout = timedelta(days=7)
        else:
            pg_options = None
        # Call the init process
        torch.distributed.init_process_group(
            backend=args.distributed_backend,
            world_size=args.world_size, rank=args.rank,
            timeout=timedelta(days=7),
            pg_options=pg_options)
    # Set the tensor model-parallel, pipeline model-parallel, and
    # data-parallel communicators.
......
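As a standalone sketch, the high-priority group setup above can be exercised on a single process; the MASTER_ADDR/MASTER_PORT values are assumptions for a one-rank group, and ProcessGroupNCCL.Options is a private PyTorch API whose shape may vary across versions:

    import os
    from datetime import timedelta
    import torch
    from torch._C._distributed_c10d import ProcessGroupNCCL

    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    pg_options = ProcessGroupNCCL.Options()
    # With this set, NCCL launches its kernels on a high-priority CUDA stream,
    # so the all-reduce can be scheduled ahead of default-priority compute.
    pg_options.is_high_priority_stream = True
    torch.distributed.init_process_group(
        backend='nccl', init_method='env://', world_size=1, rank=0,
        timeout=timedelta(days=7), pg_options=pg_options)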
......@@ -27,6 +27,7 @@ from torch.nn.parameter import Parameter
from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_world_size
from .initialize import get_tensor_model_parallel_group
from .mappings import copy_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region
......@@ -198,6 +199,37 @@ class VocabParallelEmbedding(torch.nn.Module):
        return output
class ColumnParallelLinearFunction(torch.autograd.Function):
    """
    Column-parallel linear layer execution with asynchronous all-reduce
    of the input gradient in backprop.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, use_bias):
        ctx.save_for_backward(input, weight)
        ctx.use_bias = use_bias
        output = torch.matmul(input, weight.t())
        if use_bias:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        use_bias = ctx.use_bias
        grad_input = grad_output.matmul(weight)
        # Asynchronous all-reduce
        handle = torch.distributed.all_reduce(
            grad_input, group=get_tensor_model_parallel_group(), async_op=True)
        # Delay the start of the weight-gradient GEMM with a dummy add (~3us)
        # so the all-reduce is scheduled first and has its GPU resources
        # allocated.
        _ = torch.empty(1, device=grad_output.device) + 1
        grad_weight = grad_output.t().matmul(input)
        grad_bias = grad_output.sum(dim=0) if use_bias else None
        handle.wait()
        return grad_input, grad_weight, grad_bias, None
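A minimal sanity check of the function above, assuming torch.distributed and the tensor-parallel groups are already initialized (the sizes are made up):

    # x: (tokens, hidden); w: (output_size_per_partition, hidden)
    x = torch.randn(16, 1024, device='cuda', requires_grad=True)
    w = torch.randn(4096, 1024, device='cuda', requires_grad=True)
    y = ColumnParallelLinearFunction.apply(x, w, None, False)
    # In backward(), the wgrad GEMM runs while the grad_input all-reduce is
    # still in flight; handle.wait() orders them before gradients are returned.
    y.sum().backward()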
class ColumnParallelLinear(torch.nn.Module):
"""Linear layer with column parallelism.
......@@ -272,16 +304,29 @@ class ColumnParallelLinear(torch.nn.Module):
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)
        self.async_tensor_parallel_allreduce = (
            args.async_tensor_parallel_allreduce and world_size > 1)
    def forward(self, input_):
        bias = self.bias if not self.skip_bias_add else None
        if self.async_tensor_parallel_allreduce:
            input_shape = input_.shape
            input_ = input_.view(input_shape[0] * input_shape[1], input_shape[2])
            # Matrix multiply with asynchronous tensor-parallel all-reduce
            # execution in backprop.
            output_parallel = ColumnParallelLinearFunction.apply(
                input_, self.weight, bias, bias is not None)
            output_parallel = output_parallel.view(
                input_shape[0], input_shape[1], output_parallel.shape[1])
        else:
            # Set up backprop all-reduce.
            input_parallel = copy_to_tensor_model_parallel_region(input_)
            # Matrix multiply.
            output_parallel = F.linear(input_parallel, self.weight, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = gather_from_tensor_model_parallel_region(output_parallel)
......
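The forward path above folds the 3-D activation into 2-D so the fused function sees a single GEMM, then restores the leading dimensions; a shape-only sketch with assumed sizes:

    import torch
    s, b, h, o = 32, 4, 1024, 2048        # seq, batch, hidden, out-per-partition
    inp = torch.randn(s, b, h)
    flat = inp.view(s * b, h)             # 2-D input handed to the function
    out = flat.matmul(torch.randn(h, o))  # stand-in for the column-parallel GEMM
    out = out.view(s, b, -1)              # back to (seq, batch, out) layout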