Unverified commit 47ce21ac authored by Min Xu, committed by GitHub

[fix] don't import ProcessGroup eagerly (#1074)



* [fix] don't import ProcessGroup eagerly

- move the import under TYPE_CHECKING since it is only used for type checking
- fixes #1057

* more fixes

* one more

* tested at least
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent d8fc94d9
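
The core of the change, sketched below for reference (this mirrors the hunks that follow; it is an illustration, not additional code from the commit): the ProcessGroup import runs only under typing.TYPE_CHECKING, and annotations refer to the class by the string "ProcessGroup", so importing the module never requires torch.distributed.ProcessGroup to exist at runtime.

    from typing import TYPE_CHECKING

    import torch

    if TYPE_CHECKING:
        # Evaluated only by static type checkers (e.g. mypy), never at runtime,
        # so platforms whose torch.distributed lacks ProcessGroup still import cleanly.
        from torch.distributed import ProcessGroup


    def validate_process_group(device: torch.device, process_group: "ProcessGroup") -> None:
        # The quoted annotation is a forward reference; it is not resolved when the
        # function is defined, so ProcessGroup need not be importable at runtime.
        ...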
@@ -7,13 +7,16 @@
 from enum import Enum
 import sys
-from typing import List, Optional, Sequence
+from typing import TYPE_CHECKING, List, Optional, Sequence
 
 import torch
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
 import torch.nn.functional as F
 
+if TYPE_CHECKING:
+    # See comments in FSDP code for reason of this import.
+    from torch.distributed import ProcessGroup
+
 
 def chunk_and_pad(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]:
     """Chunk a given Tensor into num_chunks parts and add any necessary padding."""
@@ -27,7 +30,7 @@ def chunk_and_pad(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]:
     return chunks
 
 
-def validate_process_group(device: torch.device, process_group: ProcessGroup) -> None:
+def validate_process_group(device: torch.device, process_group: "ProcessGroup") -> None:
     """Do a quick test in case user called FSDP without calling torch.cuda.set_device()
     correctly. This can easily happen in cpu_offload case where the model resides on
     the CPU.
@@ -67,7 +70,7 @@ class ProcessGroupName(str, Enum):
 def get_process_group_cached(
     name: ProcessGroupName = ProcessGroupName.default, ranks: Optional[Sequence[int]] = None
-) -> ProcessGroup:
+) -> "ProcessGroup":
     """
     Singleton PyTorch distributed group cache. Inspired by the code from fairseq.
...
@@ -5,12 +5,14 @@
 import functools
 import os
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
 
 import torch
 from torch import Tensor
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
+
+if TYPE_CHECKING:
+    from torch.distributed import ProcessGroup
 
 # TODO: Remove the toggle-enable_nccl_base_collectives when github open issue #801 is resolved.
 if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0":
@@ -20,7 +22,7 @@ else:
 
 class Bucket:
-    def __init__(self, data: Tensor, group: ProcessGroup):
+    def __init__(self, data: Tensor, group: "ProcessGroup"):
         self.data = data
         self.group = group
         self.offset = 0
@@ -99,13 +101,13 @@ class ReduceScatterBucketer:
     def __init__(self, bucket_cap_mb: int = 25):
         self.bucket_cap_mb = bucket_cap_mb
-        self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {}
+        self.buckets: Dict[Tuple[torch.dtype, torch.device, "ProcessGroup"], Bucket] = {}
 
     @torch.no_grad()
     def reduce_scatter_async(
         self,
         input_list: List[Tensor],
-        group: ProcessGroup,
+        group: "ProcessGroup",
         callback_fn: Optional[Callable] = None,
     ) -> None:
         """
@@ -186,7 +188,7 @@ class ReduceScatterBucketer:
         bucket_size = self.bucket_cap_mb * MB / element_size
         return int(bucket_size // num_shards)
 
-    def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket:
+    def _get_bucket(self, tensor: Tensor, group: "ProcessGroup") -> Bucket:
         # TODO (Min): the `group` used here in the key is the object hash, not the content
         # hash. That means if FSDP instances are initialized with different process groups,
         # even when the group members are in fact the same, we end up creating different
...
@@ -5,11 +5,18 @@
 from typing import List
 
+import torch.distributed as dist
+
 from .checkpoint import checkpoint_wrapper
-from .data_parallel import FullyShardedDataParallel, ShardedDataParallel
+from .data_parallel import FullyShardedDataParallel
+
+if dist.is_available():
+    # Prevent import failure if dist is not available. #1057
+    from .data_parallel import ShardedDataParallel
+    from .moe import MOELayer, Top2Gate
+    from .pipe import Pipe, PipeRPCWrapper
+
 from .misc import FlattenParamsWrapper
-from .moe import MOELayer, Top2Gate
-from .pipe import Pipe, PipeRPCWrapper
 from .wrap import auto_wrap, config_auto_wrap_policy, default_auto_wrap_policy, enable_wrap, wrap
 
 __all__: List[str] = []
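
As a hedged usage note (not part of the diff): user code that must stay importable on PyTorch builds without distributed support can apply the same availability check before touching the conditionally exported names.

    import torch.distributed as dist

    from fairscale.nn import FullyShardedDataParallel  # exported unconditionally

    if dist.is_available():
        # These names are only re-exported when torch.distributed is available (#1057).
        from fairscale.nn import MOELayer, Pipe, ShardedDataParallel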
@@ -5,6 +5,8 @@
 from typing import List
 
+import torch.distributed as dist
+
 from .fully_sharded_data_parallel import (
     FullyShardedDataParallel,
     OffloadConfig,
@@ -12,6 +14,9 @@ from .fully_sharded_data_parallel import (
     auto_wrap_bn,
     no_pre_load_state_dict_hook,
 )
-from .sharded_ddp import ShardedDataParallel
+
+if dist.is_available():
+    # Prevent import failure if dist is not available. #1057
+    from .sharded_ddp import ShardedDataParallel
 
 __all__: List[str] = []
@@ -35,7 +35,6 @@ from typing import (
 import torch
 from torch.autograd import Variable
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
@@ -58,6 +57,12 @@ from . import fsdp_optim_utils as ou
 if TYPE_CHECKING:
     from collections import OrderedDict  # noqa: F401
 
+    # See #1057. On some platform, torch.distributed may not have ProcessGroup
+    # So we only import it during type checking, which is not done on default
+    # import and only done by developer (doing it on supported platforms I presume).
+    from torch.distributed import ProcessGroup
+
 # TODO: Remove the toggle here when github open issue #801 is resolved.
 if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0":
     enable_nccl_base_collectives = False
@@ -308,7 +313,7 @@ class FullyShardedDataParallel(nn.Module):
     def __init__(
         self,
         module: nn.Module,
-        process_group: Optional[ProcessGroup] = None,
+        process_group: Optional["ProcessGroup"] = None,
         # The type for the process_group_reduce_scatter only can be either ProcessGroup or ProcessGroupName
         process_group_reduce_scatter: Any = ProcessGroupName.reduce_scatter,
         reshard_after_forward: bool = True,
@@ -352,6 +357,9 @@ class FullyShardedDataParallel(nn.Module):
             self.process_group_reduce_scatter = get_process_group_cached(ProcessGroupName.reduce_scatter)
         else:
             # If a specific process group is passed in, the reduce_scatter will use the passed in process group.
+            # Delay the import here since this type may not be available on certain platforms.
+            from torch.distributed import ProcessGroup
+
             if isinstance(process_group_reduce_scatter, ProcessGroup):
                 self.process_group_reduce_scatter = process_group_reduce_scatter
             else:
@@ -2648,7 +2656,7 @@ def _unpad(shard: torch.Tensor, pad: int) -> torch.Tensor:
 def auto_wrap_bn(
     module: nn.Module,
     single_rank_pg: bool = False,
-    process_group: Optional[ProcessGroup] = None,
+    process_group: Optional["ProcessGroup"] = None,
     fsdp_config: Optional[Dict[str, Any]] = None,
     wrap_it: bool = True,
     assert_on_collision: bool = True,
...
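
For completeness, a standalone sketch of the call-time import used in the constructor hunk above; the function name and return strings are illustrative, not fairscale API.

    def describe_group(maybe_group: object) -> str:
        # ProcessGroup is imported only when the function runs, so merely importing
        # the enclosing module never requires torch.distributed.ProcessGroup to exist.
        from torch.distributed import ProcessGroup

        if isinstance(maybe_group, ProcessGroup):
            return f"ProcessGroup with {maybe_group.size()} ranks"
        return "not a ProcessGroup"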