Commit 16c90445 authored by slym

minor changes: rename the --async-tensor-parallel-allreduce argument to --async-tensor-model-parallel-allreduce and ColumnParallelLinearFunction to ColumnParallelLinearWithAsyncAllreduce

parent cf7efd4f
@@ -462,10 +462,8 @@ def _add_training_args(parser):
     group.add_argument('--dataloader-type', type=str, default=None,
                        choices=['single', 'cyclic'],
                        help='Single pass vs multiple pass data loader')
-    group.add_argument('--async-tensor-parallel-allreduce', action='store_true',
-                       help='Enable asynchronous execution of tensor-parallel allreduce '
-                       'with other GPU operators',
-                       dest='async_tensor_parallel_allreduce')
+    group.add_argument('--async-tensor-model-parallel-allreduce',
+                       action='store_true')
     return parser
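The rename also drops the explicit dest=: argparse derives the attribute name from the option string by stripping the leading dashes and turning '-' into '_', so the new flag is read back as args.async_tensor_model_parallel_allreduce. A minimal, self-contained sketch of that behavior (not the Megatron parser; the help text is assumed to carry over from the removed flag):

```python
import argparse

# Stand-in for Megatron's _add_training_args(); only the renamed flag is shown.
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='training')
group.add_argument('--async-tensor-model-parallel-allreduce',
                   action='store_true',
                   help='Enable asynchronous execution of tensor-model-parallel '
                        'all-reduce with other GPU operators')  # assumed help text

args = parser.parse_args(['--async-tensor-model-parallel-allreduce'])
# argparse turns '--async-tensor-model-parallel-allreduce' into this attribute,
# which is why the old explicit dest= was redundant.
assert args.async_tensor_model_parallel_allreduce is True
```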
@@ -177,7 +177,7 @@ def _initialize_distributed():
                 args.local_rank = device
             torch.cuda.set_device(device)
         # Increase cuda stream priority of NCCL ops when overlapping with other ops
-        if (args.async_tensor_parallel_allreduce and
+        if (args.async_tensor_model_parallel_allreduce and
                 args.tensor_model_parallel_size > 1):
             from torch._C._distributed_c10d import ProcessGroupNCCL
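The lines that actually configure the communicator are cut off by the collapsed diff. A hedged sketch of what requesting a high-priority NCCL stream for a tensor-model-parallel group typically looks like (the helper name and the new_group call are illustrative, not the commit's code):

```python
import torch.distributed as dist
from torch._C._distributed_c10d import ProcessGroupNCCL

def new_tensor_model_parallel_group(ranks):
    # Illustrative helper, not part of the commit. NCCL is asked to run its
    # kernels on a high-priority CUDA stream so the tensor-model-parallel
    # all-reduce can overlap with, rather than queue behind, compute kernels.
    nccl_options = ProcessGroupNCCL.Options()
    nccl_options.is_high_priority_stream = True
    return dist.new_group(ranks=ranks, pg_options=nccl_options)
```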
@@ -199,7 +199,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         return output


-class ColumnParallelLinearFunction(torch.autograd.Function):
+class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
     """
     Column-parallel linear layer execution with asynchronous all-reduce
     execution in backprop.
@@ -304,19 +304,19 @@ class ColumnParallelLinear(torch.nn.Module):
                 self.bias.zero_()
         else:
             self.register_parameter('bias', None)
-        self.async_tensor_parallel_allreduce = (args.async_tensor_parallel_allreduce
-                                                and world_size > 1)
+        self.async_tensor_model_parallel_allreduce = (
+            args.async_tensor_model_parallel_allreduce and world_size > 1)

     def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None
-        if self.async_tensor_parallel_allreduce:
+        if self.async_tensor_model_parallel_allreduce:
             input_shape = input_.shape
             input_ = input_.view(input_shape[0] * input_shape[1], input_shape[2])
-            # Matrix multiply with asynchronous tensor-parallel all-reduce execution
-            output_parallel = ColumnParallelLinearFunction.apply(
+            # Matrix multiply with asynchronous all-reduce execution
+            output_parallel = ColumnParallelLinearWithAsyncAllreduce.apply(
                 input_, self.weight, bias, bias is not None)
             output_parallel = output_parallel.view(
                 input_shape[0], input_shape[1], output_parallel.shape[1])
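For reference, the pattern behind ColumnParallelLinearWithAsyncAllreduce: the forward pass is an ordinary GEMM on the (tokens, hidden) view of the input, and the backward pass launches the all-reduce of the input gradient as an async op so it overlaps with the weight- and bias-gradient GEMMs. The sketch below is a hedged, self-contained illustration of that idea, not the file's implementation; it uses the default torch.distributed process group, whereas Megatron would use its tensor-model-parallel group, and its apply() signature differs from the one in the diff.

```python
import torch
import torch.distributed as dist


class LinearWithAsyncAllreduce(torch.autograd.Function):
    """Illustrative only: GEMM forward; in backward the grad_input all-reduce
    is overlapped with the weight/bias gradient GEMMs."""

    @staticmethod
    def forward(ctx, input_, weight, bias):
        # input_: (tokens, hidden_in), weight: (hidden_out, hidden_in)
        ctx.save_for_backward(input_, weight)
        ctx.use_bias = bias is not None
        output = input_.matmul(weight.t())
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight = ctx.saved_tensors
        grad_input = grad_output.matmul(weight)
        # Kick off the all-reduce of grad_input asynchronously; a column-parallel
        # layer needs the input gradient summed over tensor-parallel ranks.
        handle = dist.all_reduce(grad_input, async_op=True)
        # Overlap: these GEMMs run on the compute stream while NCCL reduces
        # grad_input on its own (high-priority) stream.
        grad_weight = grad_output.t().matmul(input_)
        grad_bias = grad_output.sum(dim=0) if ctx.use_bias else None
        handle.wait()
        return grad_input, grad_weight, grad_bias
```

Combined with the high-priority NCCL stream requested in _initialize_distributed(), the all-reduce can make progress while the default stream is still busy with the gradient GEMMs.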