Commit 16c90445 authored by slym

minor changes

parent cf7efd4f
@@ -462,10 +462,8 @@ def _add_training_args(parser):
     group.add_argument('--dataloader-type', type=str, default=None,
                        choices=['single', 'cyclic'],
                        help='Single pass vs multiple pass data loader')
-    group.add_argument('--async-tensor-parallel-allreduce', action='store_true',
-                       help='Enable asynchronous execution of tensor-parallel allreduce '
-                            'with other GPU operators',
-                       dest='async_tensor_parallel_allreduce')
+    group.add_argument('--async-tensor-model-parallel-allreduce',
+                       action='store_true')
     return parser
...
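Note: after this rename the explicit `dest=` is no longer needed, because argparse derives the attribute name from the long option by dropping the leading dashes and replacing the remaining dashes with underscores. A minimal standalone sketch (not the repo's `_add_training_args`; the group title here is only part of the sketch):

```python
import argparse

# Standalone sketch of the renamed flag; the real option is registered on an
# argument group inside _add_training_args.
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='training')
group.add_argument('--async-tensor-model-parallel-allreduce',
                   action='store_true')

args = parser.parse_args(['--async-tensor-model-parallel-allreduce'])
# argparse maps '--async-tensor-model-parallel-allreduce' to this attribute.
assert args.async_tensor_model_parallel_allreduce is True
```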
@@ -177,7 +177,7 @@ def _initialize_distributed():
            args.local_rank = device
        torch.cuda.set_device(device)
        # Increase cuda stream priority of NCCL ops when overlapping with other ops
-        if (args.async_tensor_parallel_allreduce and
+        if (args.async_tensor_model_parallel_allreduce and
                args.tensor_model_parallel_size > 1):
            from torch._C._distributed_c10d import ProcessGroupNCCL
...
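The hunk only shows the import, so the wiring below is a sketch rather than the commit's actual code: on PyTorch builds where `ProcessGroupNCCL.Options` exposes `is_high_priority_stream` and `torch.distributed.new_group` accepts `pg_options`, the tensor-model-parallel process group could be given a higher-priority NCCL stream roughly like this (the helper name and its arguments are hypothetical):

```python
import torch.distributed as dist
from torch._C._distributed_c10d import ProcessGroupNCCL


def new_high_priority_nccl_group(ranks, use_high_priority_stream):
    """Hypothetical helper: build an NCCL process group whose collectives run
    on a higher-priority CUDA stream, so the asynchronous tensor-model-parallel
    all-reduce can overlap with compute kernels. Assumes torch.distributed has
    already been initialized with the NCCL backend."""
    pg_options = None
    if use_high_priority_stream:
        pg_options = ProcessGroupNCCL.Options()
        pg_options.is_high_priority_stream = True
    return dist.new_group(ranks=ranks, backend='nccl', pg_options=pg_options)
```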
@@ -199,7 +199,7 @@ class VocabParallelEmbedding(torch.nn.Module):
        return output


-class ColumnParallelLinearFunction(torch.autograd.Function):
+class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
    """
    Column-parallel linear layer execution with asynchronous all-reduce
    execution in backprop.
@@ -304,19 +304,19 @@ class ColumnParallelLinear(torch.nn.Module):
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)
-        self.async_tensor_parallel_allreduce = (args.async_tensor_parallel_allreduce
-                                                and world_size > 1)
+        self.async_tensor_model_parallel_allreduce = (
+            args.async_tensor_model_parallel_allreduce and world_size > 1)

    def forward(self, input_):
        bias = self.bias if not self.skip_bias_add else None

-        if self.async_tensor_parallel_allreduce:
+        if self.async_tensor_model_parallel_allreduce:
            input_shape = input_.shape
            input_ = input_.view(input_shape[0] * input_shape[1], input_shape[2])
-            # Matrix multiply with asynchronous tensor-parallel all-reduce execution
-            output_parallel = ColumnParallelLinearFunction.apply(
+            # Matrix multiply with asynchronous all-reduce execution
+            output_parallel = ColumnParallelLinearWithAsyncAllreduce.apply(
                input_, self.weight, bias, bias is not None)
            output_parallel = output_parallel.view(
                input_shape[0], input_shape[1], output_parallel.shape[1])
...
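The body of `ColumnParallelLinearWithAsyncAllreduce` is not part of this diff, so the sketch below only illustrates the idea in its docstring, with a hypothetical name and signature (the real class is invoked as `.apply(input_, self.weight, bias, bias is not None)` and presumably obtains the tensor-model-parallel group from the surrounding model-parallel setup): in backward, the all-reduce of the input gradient is launched with `async_op=True` and overlapped with the weight-gradient GEMM before the handle is waited on.

```python
import torch
import torch.distributed as dist


class AsyncAllreduceLinearSketch(torch.autograd.Function):
    """Illustrative stand-in for ColumnParallelLinearWithAsyncAllreduce.
    Name and signature are hypothetical; only the overlap pattern matters."""

    @staticmethod
    def forward(ctx, input_, weight, bias, group):
        # input_: [s*b, h_in], weight: [h_out_partition, h_in]
        ctx.save_for_backward(input_, weight)
        ctx.group = group
        ctx.use_bias = bias is not None
        output = torch.matmul(input_, weight.t())
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight = ctx.saved_tensors
        # Each tensor-parallel rank holds only a slice of the output dimension,
        # so its input gradient is partial and must be summed across the group.
        grad_input = grad_output.matmul(weight)
        # Launch the all-reduce without blocking ...
        handle = dist.all_reduce(grad_input, group=ctx.group, async_op=True)
        # ... and overlap it with the weight/bias gradient computation.
        grad_weight = grad_output.t().matmul(input_)
        grad_bias = grad_output.sum(dim=0) if ctx.use_bias else None
        handle.wait()
        return grad_input, grad_weight, grad_bias, None
```

Flattening `[s, b, h]` to `[s*b, h]` before `.apply`, as the hunk above does, presumably keeps the explicit weight-gradient GEMM in backward a plain 2D matmul; the output is reshaped back to `[s, b, h']` after the custom function returns.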