Unverified commit 47ce21ac authored by Min Xu, committed by GitHub

[fix] don't import ProcessGroup eagerly (#1074)



* [fix] don't import ProcessGroup eagerly

- move the import under TYPE_CHECKING since it is only needed for type checking
- fixes #1057

* more fixes

* one more

* tested at least
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent d8fc94d9
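
For reference, the core idea in the type-annotation part of this diff is the standard typing pattern: import ProcessGroup only while type checking and refer to it via string annotations, so importing the module never touches torch.distributed at runtime. A minimal sketch under those assumptions (hypothetical module and function names, not part of this commit):

from typing import TYPE_CHECKING, Optional

import torch

if TYPE_CHECKING:
    # Evaluated only by static type checkers (mypy/pyright), never at runtime,
    # so platforms whose torch.distributed lacks ProcessGroup still import fine.
    from torch.distributed import ProcessGroup


def describe_group(device: torch.device, group: Optional["ProcessGroup"] = None) -> str:
    # The quoted annotation keeps ProcessGroup out of runtime name lookup.
    return f"device={device}, group={'default' if group is None else 'custom'}"
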
@@ -7,13 +7,16 @@
 from enum import Enum
 import sys
-from typing import List, Optional, Sequence
+from typing import TYPE_CHECKING, List, Optional, Sequence
 
 import torch
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
 import torch.nn.functional as F
 
+if TYPE_CHECKING:
+    # See comments in FSDP code for reason of this import.
+    from torch.distributed import ProcessGroup
+
 
 def chunk_and_pad(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]:
     """Chunk a given Tensor into num_chunks parts and add any necessary padding."""
@@ -27,7 +30,7 @@ def chunk_and_pad(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]:
     return chunks
 
 
-def validate_process_group(device: torch.device, process_group: ProcessGroup) -> None:
+def validate_process_group(device: torch.device, process_group: "ProcessGroup") -> None:
     """Do a quick test in case user called FSDP without calling torch.cuda.set_device()
     correctly. This can easily happen in cpu_offload case where the model resides on
    the CPU.
@@ -67,7 +70,7 @@ class ProcessGroupName(str, Enum):
 def get_process_group_cached(
     name: ProcessGroupName = ProcessGroupName.default, ranks: Optional[Sequence[int]] = None
-) -> ProcessGroup:
+) -> "ProcessGroup":
     """
     Singleton PyTorch distributed group cache. Inspired by the code from fairseq.
@@ -5,12 +5,14 @@
 import functools
 import os
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
 
 import torch
 from torch import Tensor
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
+
+if TYPE_CHECKING:
+    from torch.distributed import ProcessGroup
 
 # TODO: Remove the toggle-enable_nccl_base_collectives when github open issue #801 is resolved.
 if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0":
@@ -20,7 +22,7 @@ else:
 
 class Bucket:
-    def __init__(self, data: Tensor, group: ProcessGroup):
+    def __init__(self, data: Tensor, group: "ProcessGroup"):
         self.data = data
         self.group = group
         self.offset = 0
@@ -99,13 +101,13 @@ class ReduceScatterBucketer:
     def __init__(self, bucket_cap_mb: int = 25):
         self.bucket_cap_mb = bucket_cap_mb
-        self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {}
+        self.buckets: Dict[Tuple[torch.dtype, torch.device, "ProcessGroup"], Bucket] = {}
 
     @torch.no_grad()
     def reduce_scatter_async(
         self,
         input_list: List[Tensor],
-        group: ProcessGroup,
+        group: "ProcessGroup",
         callback_fn: Optional[Callable] = None,
     ) -> None:
         """
@@ -186,7 +188,7 @@ class ReduceScatterBucketer:
         bucket_size = self.bucket_cap_mb * MB / element_size
         return int(bucket_size // num_shards)
 
-    def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket:
+    def _get_bucket(self, tensor: Tensor, group: "ProcessGroup") -> Bucket:
         # TODO (Min): the `group` used here in the key is the object hash, not the content
         # hash. That means if FSDP instances are initialized with different process groups,
         # even when the group members are in fact the same, we end up creating different
@@ -5,11 +5,18 @@
 from typing import List
 
+import torch.distributed as dist
+
 from .checkpoint import checkpoint_wrapper
-from .data_parallel import FullyShardedDataParallel, ShardedDataParallel
+from .data_parallel import FullyShardedDataParallel
+
+if dist.is_available():
+    # Prevent import failure if dist is not available. #1057
+    from .data_parallel import ShardedDataParallel
+    from .moe import MOELayer, Top2Gate
+    from .pipe import Pipe, PipeRPCWrapper
+
 from .misc import FlattenParamsWrapper
-from .moe import MOELayer, Top2Gate
-from .pipe import Pipe, PipeRPCWrapper
 from .wrap import auto_wrap, config_auto_wrap_policy, default_auto_wrap_policy, enable_wrap, wrap
 
 __all__: List[str] = []
@@ -5,6 +5,8 @@
 from typing import List
 
+import torch.distributed as dist
+
 from .fully_sharded_data_parallel import (
     FullyShardedDataParallel,
     OffloadConfig,
@@ -12,6 +14,9 @@ from .fully_sharded_data_parallel import (
     auto_wrap_bn,
     no_pre_load_state_dict_hook,
 )
-from .sharded_ddp import ShardedDataParallel
+
+if dist.is_available():
+    # Prevent import failure if dist is not available. #1057
+    from .sharded_ddp import ShardedDataParallel
 
 __all__: List[str] = []
@@ -35,7 +35,6 @@ from typing import (
 import torch
 from torch.autograd import Variable
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
@@ -58,6 +57,12 @@ from . import fsdp_optim_utils as ou
 if TYPE_CHECKING:
     from collections import OrderedDict  # noqa: F401
 
+    # See #1057. On some platform, torch.distributed may not have ProcessGroup
+    # So we only import it during type checking, which is not done on default
+    # import and only done by developer (doing it on supported platforms I presume).
+    from torch.distributed import ProcessGroup
+
 
 # TODO: Remove the toggle here when github open issue #801 is resolved.
 if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0":
     enable_nccl_base_collectives = False
@@ -308,7 +313,7 @@ class FullyShardedDataParallel(nn.Module):
     def __init__(
         self,
         module: nn.Module,
-        process_group: Optional[ProcessGroup] = None,
+        process_group: Optional["ProcessGroup"] = None,
         # The type for the process_group_reduce_scatter only can be either ProcessGroup or ProcessGroupName
         process_group_reduce_scatter: Any = ProcessGroupName.reduce_scatter,
         reshard_after_forward: bool = True,
@@ -352,6 +357,9 @@ class FullyShardedDataParallel(nn.Module):
             self.process_group_reduce_scatter = get_process_group_cached(ProcessGroupName.reduce_scatter)
         else:
             # If a specific process group is passed in, the reduce_scatter will use the passed in process group.
+            # Delay the import here since this type may not be available on certain platforms.
+            from torch.distributed import ProcessGroup
+
             if isinstance(process_group_reduce_scatter, ProcessGroup):
                 self.process_group_reduce_scatter = process_group_reduce_scatter
             else:
@@ -2648,7 +2656,7 @@ def _unpad(shard: torch.Tensor, pad: int) -> torch.Tensor:
 def auto_wrap_bn(
     module: nn.Module,
     single_rank_pg: bool = False,
-    process_group: Optional[ProcessGroup] = None,
+    process_group: Optional["ProcessGroup"] = None,
     fsdp_config: Optional[Dict[str, Any]] = None,
     wrap_it: bool = True,
     assert_on_collision: bool = True,
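
The runtime side of the fix pairs two guards, as shown in the hunks above: a module-level dist.is_available() check around distributed-only imports, and a function-local import of ProcessGroup at the point where an isinstance check actually needs the class. A hedged sketch of that combination (hypothetical names, not fairscale code):

import torch.distributed as dist

if dist.is_available():
    # Only reached on builds that ship torch.distributed (see #1057).
    from torch.distributed import ProcessGroup  # noqa: F401


def resolve_group(candidate: object) -> object:
    # Deferred import: merely importing this module never requires ProcessGroup;
    # the lookup happens only when a caller passes an explicit group.
    from torch.distributed import ProcessGroup

    if isinstance(candidate, ProcessGroup):
        return candidate
    raise TypeError(f"expected a ProcessGroup, got {type(candidate).__name__}")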