Unverified Commit 086402d5 authored by tmarkstrum's avatar tmarkstrum Committed by GitHub
Browse files

add toggler to disable using the nccl base collectives (#799)

* add toggler to disable using the nccl base collectives

* added todo to remove the toggle when the issue is resolved.
parent 180ab8c8
...@@ -9,6 +9,7 @@ from enum import Enum, auto ...@@ -9,6 +9,7 @@ from enum import Enum, auto
import functools import functools
import logging import logging
from math import inf from math import inf
import os
import time import time
import traceback import traceback
import typing import typing
...@@ -54,6 +55,11 @@ from . import fsdp_optim_utils as ou ...@@ -54,6 +55,11 @@ from . import fsdp_optim_utils as ou
if TYPE_CHECKING: if TYPE_CHECKING:
from collections import OrderedDict # noqa: F401 from collections import OrderedDict # noqa: F401
# TODO: Remove the toggle here when github open issue #801 is resolved.
# Opt-out switch: setting ENABLE_NCCL_BASE_COLLECTIVES=0 disables the
# _all_gather_base fast path; any other value (or unset) leaves it enabled.
enable_nccl_base_collectives = os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") != "0"
class TrainingState(Enum): class TrainingState(Enum):
...@@ -1599,7 +1605,7 @@ class FullyShardedDataParallel(nn.Module): ...@@ -1599,7 +1605,7 @@ class FullyShardedDataParallel(nn.Module):
output_tensor = p._full_param_padded output_tensor = p._full_param_padded
# Fill output_tensor with (p.data for each shard in self.world_size) # Fill output_tensor with (p.data for each shard in self.world_size)
if hasattr(dist, "_all_gather_base"): if hasattr(dist, "_all_gather_base") and enable_nccl_base_collectives:
# New version of PyTorch has all_gather_base, which is faster than chunk and then all_gather. # New version of PyTorch has all_gather_base, which is faster than chunk and then all_gather.
dist._all_gather_base(output_tensor, p_data, group=self.process_group) dist._all_gather_base(output_tensor, p_data, group=self.process_group)
else: else:
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import functools import functools
import os
from typing import Callable, Dict, List, Optional, Tuple from typing import Callable, Dict, List, Optional, Tuple
import torch import torch
...@@ -11,6 +12,12 @@ from torch import Tensor ...@@ -11,6 +12,12 @@ from torch import Tensor
import torch.distributed as dist import torch.distributed as dist
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
# TODO: Remove the toggle-enable_nccl_base_collectives when github open issue #801 is resolved.
# Opt-out switch: setting ENABLE_NCCL_BASE_COLLECTIVES=0 disables the
# _reduce_scatter_base fast path; any other value (or unset) leaves it enabled.
enable_nccl_base_collectives = os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") != "0"
class Bucket: class Bucket:
def __init__(self, data: Tensor, group: ProcessGroup): def __init__(self, data: Tensor, group: ProcessGroup):
...@@ -26,7 +33,7 @@ class Bucket: ...@@ -26,7 +33,7 @@ class Bucket:
assert len(self.callbacks) == 0 assert len(self.callbacks) == 0
return return
# reduce-scatter bucket # reduce-scatter bucket
if hasattr(dist, "_reduce_scatter_base"): if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
dist._reduce_scatter_base( dist._reduce_scatter_base(
self.output_shard[: self.offset], self.data[:, : self.offset].contiguous(), group=self.group self.output_shard[: self.offset], self.data[:, : self.offset].contiguous(), group=self.group
) )
...@@ -130,7 +137,7 @@ class ReduceScatterBucketer: ...@@ -130,7 +137,7 @@ class ReduceScatterBucketer:
# TODO: investigate how to avoid using torch.cat (because it seems to be slow for CPU tensors) # TODO: investigate how to avoid using torch.cat (because it seems to be slow for CPU tensors)
# input is too big to fit in the bucket, reduce-scatter directly # input is too big to fit in the bucket, reduce-scatter directly
output = torch.zeros_like(input_list[0]) output = torch.zeros_like(input_list[0])
if hasattr(dist, "_reduce_scatter_base"): if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
input_flattened = torch.cat(input_list) input_flattened = torch.cat(input_list)
dist._reduce_scatter_base(output, input_flattened, group=group) dist._reduce_scatter_base(output, input_flattened, group=group)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment