Unverified commit 086402d5 authored by tmarkstrum, committed by GitHub

add toggle to disable using the NCCL base collectives (#799)

* add toggle to disable using the NCCL base collectives

* add a TODO to remove the toggle once issue #801 is resolved.
parent 180ab8c8
@@ -9,6 +9,7 @@ from enum import Enum, auto
 import functools
 import logging
 from math import inf
+import os
 import time
 import traceback
 import typing
@@ -54,6 +55,11 @@ from . import fsdp_optim_utils as ou
 if TYPE_CHECKING:
     from collections import OrderedDict  # noqa: F401
 
+# TODO: Remove the toggle here when github open issue #801 is resolved.
+if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0":
+    enable_nccl_base_collectives = False
+else:
+    enable_nccl_base_collectives = True
 
 
 class TrainingState(Enum):
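
Since the flag is evaluated once at module import, the environment variable has to be set before fairscale is imported. A minimal usage sketch (the setup below is illustrative, not part of this change):

```python
import os

# Opt out of the fused NCCL base collectives. This must happen before
# fairscale is imported, because the flag is read at import time.
# Equivalent shell form: ENABLE_NCCL_BASE_COLLECTIVES=0 python train.py
os.environ["ENABLE_NCCL_BASE_COLLECTIVES"] = "0"

from fairscale.nn import FullyShardedDataParallel as FSDP  # noqa: E402
```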
@@ -1599,7 +1605,7 @@ class FullyShardedDataParallel(nn.Module):
                 output_tensor = p._full_param_padded
                 # Fill output_tensor with (p.data for each shard in self.world_size)
-                if hasattr(dist, "_all_gather_base"):
+                if hasattr(dist, "_all_gather_base") and enable_nccl_base_collectives:
                     # New version of PyTorch has all_gather_base, which is faster than chunk and then all_gather.
                     dist._all_gather_base(output_tensor, p_data, group=self.process_group)
                 else:
......
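
For context, a self-contained sketch of the two all-gather paths the flag selects between. It assumes torch.distributed is already initialized with the NCCL backend; `gather_full_param`, `shard`, and `world_size` are illustrative names, not from the patch:

```python
import os

import torch
import torch.distributed as dist

# Mirrors the module-level flag added by this commit.
enable_nccl_base_collectives = os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") != "0"


def gather_full_param(shard: torch.Tensor, world_size: int, group=None) -> torch.Tensor:
    # Flat output buffer large enough to hold every rank's shard.
    output = shard.new_empty(shard.numel() * world_size)
    if hasattr(dist, "_all_gather_base") and enable_nccl_base_collectives:
        # Fast path: one fused NCCL all-gather straight into the flat buffer.
        dist._all_gather_base(output, shard, group=group)
    else:
        # Fallback: all_gather into world_size chunk views of the same buffer.
        chunks = list(output.chunk(world_size))
        dist.all_gather(chunks, shard, group=group)
    return output
```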
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import functools
+import os
 from typing import Callable, Dict, List, Optional, Tuple
 
 import torch
@@ -11,6 +12,12 @@ from torch import Tensor
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 
+# TODO: Remove the toggle-enable_nccl_base_collectives when github open issue #801 is resolved.
+if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0":
+    enable_nccl_base_collectives = False
+else:
+    enable_nccl_base_collectives = True
+
 
 class Bucket:
     def __init__(self, data: Tensor, group: ProcessGroup):
@@ -26,7 +33,7 @@ class Bucket:
             assert len(self.callbacks) == 0
             return
         # reduce-scatter bucket
-        if hasattr(dist, "_reduce_scatter_base"):
+        if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
             dist._reduce_scatter_base(
                 self.output_shard[: self.offset], self.data[:, : self.offset].contiguous(), group=self.group
             )
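
Similarly, a sketch of the two reduce-scatter paths the bucket flush toggles between, under the same assumptions as above. `flush_bucket` is an illustrative name, and the [world_size, shard_numel] layout mirrors the 2D bucket slice visible in the diff:

```python
import os

import torch
import torch.distributed as dist

enable_nccl_base_collectives = os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") != "0"


def flush_bucket(data: torch.Tensor, group=None) -> torch.Tensor:
    # data is laid out as [world_size, shard_numel]: one row per rank.
    output = data.new_empty(data.size(1))
    if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
        # Fast path: fused NCCL reduce-scatter over the whole 2D buffer.
        dist._reduce_scatter_base(output, data.contiguous(), group=group)
    else:
        # Fallback: hand the public API one row per rank.
        dist.reduce_scatter(output, list(data.unbind(0)), group=group)
    return output
```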
@@ -130,7 +137,7 @@ class ReduceScatterBucketer:
             # TODO: investigate how to avoid using torch.cat (because it seems to be slow for CPU tensors)
             # input is too big to fit in the bucket, reduce-scatter directly
             output = torch.zeros_like(input_list[0])
-            if hasattr(dist, "_reduce_scatter_base"):
+            if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
                 input_flattened = torch.cat(input_list)
                 dist._reduce_scatter_base(output, input_flattened, group=group)
             else:
......
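
One detail worth noting in this direct (unbucketed) path: the `torch.cat` is only needed on the fused branch, since the public `reduce_scatter` accepts the per-rank list directly. A sketch under the same assumptions, with illustrative names:

```python
import os
from typing import List

import torch
import torch.distributed as dist

enable_nccl_base_collectives = os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") != "0"


def reduce_scatter_direct(input_list: List[torch.Tensor], group=None) -> torch.Tensor:
    # One flat input tensor per rank; each rank receives one reduced tensor.
    output = torch.zeros_like(input_list[0])
    if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
        # The fused base collective wants a single flat input, hence the cat.
        dist._reduce_scatter_base(output, torch.cat(input_list), group=group)
    else:
        # The public API takes the per-rank list as-is, so no cat is needed.
        dist.reduce_scatter(output, input_list, group=group)
    return output
```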