Commit 3f652469 authored by slym's avatar slym
Browse files

reflect feedback

parent 16c90445
......@@ -462,8 +462,11 @@ def _add_training_args(parser):
group.add_argument('--dataloader-type', type=str, default=None,
choices=['single', 'cyclic'],
help='Single pass vs multiple pass data loader')
group.add_argument('--async-tensor-model-parallel-allreduce',
action='store_true')
group.add_argument('--no-async-tensor-model-parallel-allreduce',
action='store_true',
help='Disable asynchronous execution of '
'tensor-model-parallel all-reduce with weight '
'gradient compuation of a column-linear layer.')
return parser
......
......@@ -177,7 +177,7 @@ def _initialize_distributed():
args.local_rank = device
torch.cuda.set_device(device)
# Increase cuda stream priority of NCCL ops when overlapping with other ops
if (args.async_tensor_model_parallel_allreduce and
if (not args.no_async_tensor_model_parallel_allreduce and
args.tensor_model_parallel_size > 1):
from torch._C._distributed_c10d import ProcessGroupNCCL
......
......@@ -305,7 +305,8 @@ class ColumnParallelLinear(torch.nn.Module):
else:
self.register_parameter('bias', None)
self.async_tensor_model_parallel_allreduce = (
args.async_tensor_model_parallel_allreduce and world_size > 1)
not args.no_async_tensor_model_parallel_allreduce and
world_size > 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment