Unverified Commit bbac5564 authored by Min Xu, committed by GitHub

[minor] use dist.group.WORLD for default process group (#681)



* [minor] use dist.group.WORLD for default process group

- this is slightly more efficient than the previous commit
  for get_process_group_cached.

* fix

* better fix

* fixed for pytorch 1.6 and 1.7

* Update fairscale/utils/parallel.py

Co-authored-by: Min Xu <min.xu@acm.org>
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent 468874c8
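The PyTorch 1.6/1.7 fix mentioned in the commit message hinges on the fact noted in the diff below: on those versions `dist.group.WORLD` is a bare placeholder `object`, while from 1.8 on it is the actual default process group. A minimal standalone sketch of that check, assuming `torch.distributed` has already been initialized; the helper name `_default_world_group` is illustrative and not part of fairscale:

```python
import torch.distributed as dist


def _default_world_group():
    # Hypothetical helper mirroring the diff's logic: on PyTorch >= 1.8,
    # dist.group.WORLD is already the default process group; on 1.6/1.7 it is
    # a plain sentinel object, so a real group has to be created explicitly.
    assert dist.is_initialized(), "call dist.init_process_group() first"
    default_pg = dist.group.WORLD
    if type(default_pg) == object:
        # dist.new_group() is a collective call: every rank must reach it.
        default_pg = dist.new_group()
    return default_pg
```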
@@ -5,7 +5,7 @@
 
 """Useful functions for parallel training."""
 
-from typing import List, Optional
+from typing import List, Optional, Sequence
 
 import torch
 import torch.distributed as dist
@@ -58,10 +58,14 @@ def enable_pytorch_sync_bn(module: torch.nn.Module) -> None:
             layer._specify_ddp_gpu_num(1)  # type: ignore
 
 
-def get_process_group_cached(ranks: Optional[List[int]] = None) -> ProcessGroup:
+def get_process_group_cached(ranks: Optional[Sequence[int]] = None) -> ProcessGroup:
     """
     Singleton PyTorch distributed group cache. Inspired by the code from fairseq.
 
+    Just like torch.distributed.new_group, this method needs to be called on all ranks
+    at the same time when a new group is created. This is true for all ranks irrespective
+    of their group membership status.
+
     For FSDP, it is important to use the same group between outer and inner FSDP instances,
     otherwise, inner FSDP instances will not share the gradient reduction bucket buffer with
     the root instance. This will result in increased GPU memory utilization.
@@ -87,15 +91,26 @@ def get_process_group_cached(ranks: Optional[List[int]] = None) -> ProcessGroup:
     if not dist.is_initialized():
         raise RuntimeError("torch.distributed is not yet initialized but process group is requested.")
 
     # Init the cache if needed.
     if not hasattr(get_process_group_cached, "_global_group_cache"):
         get_process_group_cached._global_group_cache = {}  # type: ignore
+        # Populate with default process group.
+        cache = get_process_group_cached._global_group_cache  # type: ignore
+        assert dist.group.WORLD is not None
+        default_pg = dist.group.WORLD
+        if type(default_pg) == object:
+            # For PyTorch 1.6 and 1.7, dist.group.WORLD is an object, not a world
+            # process group, like that in 1.8 and 1.9.
+            default_pg = dist.new_group()
+        cache[None] = default_pg
+        cache[frozenset(list(range(dist.get_world_size())))] = default_pg
 
     # Lookup and fill the cache if needed.
     cache = get_process_group_cached._global_group_cache  # type: ignore
-    if ranks is None:
-        ranks = list(range(dist.get_world_size()))
-    ranks_set = frozenset(ranks)  # take care of ordering and duplicates in the ranks list.
-    if ranks_set not in cache:
-        cache[ranks_set] = dist.new_group(list(ranks_set))
-    return cache[ranks_set]
+    if ranks is not None:
+        # take care of ordering and duplicates in the ranks list. use tuple so that ranks
+        # can be used as a cache index.
+        ranks = tuple(sorted(list(set(ranks))))
+    if ranks not in cache:
+        cache[ranks] = dist.new_group(ranks=ranks)
+    return cache[ranks]
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-from typing import Any, List, Union, Optional
+from typing import Any, List, Union, Optional, Sequence
 from torch import Tensor
 import datetime
@@ -37,7 +37,7 @@ def broadcast_object_list(object_list: List[Any], src: int, group:Optional[Proce
 def is_initialized() -> bool: ...
 def init_process_group(backend: Union[str, Backend], init_method: Optional[str] = None, timeout: datetime.timedelta = datetime.timedelta(0, 1800), rank: Optional[int] = None, world_size: Optional[int] = None): ...
-def new_group(ranks: Optional[List[int]] = None,
+def new_group(ranks: Optional[Sequence[int]] = None,
               timeout: Optional[datetime.timedelta] = datetime.timedelta(0, 1800),
               backend: Optional[Union[str, Backend]] = None): ...
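For reference, a rough usage sketch of the cached lookup after this change, assuming `torch.distributed` has already been initialized and that every rank executes the same calls (group creation is collective); the rank lists are illustrative:

```python
import torch.distributed as dist
from fairscale.utils.parallel import get_process_group_cached

# Default group: served from the cache entry backed by dist.group.WORLD
# (or a freshly created world group on PyTorch 1.6/1.7).
world_pg = get_process_group_cached()

# Subgroups: duplicates and ordering in `ranks` are normalized to a sorted tuple,
# so both calls below resolve to the same cache key (0, 1) and the same object.
pg_a = get_process_group_cached(ranks=[0, 1])
pg_b = get_process_group_cached(ranks=[1, 0, 1])
assert pg_a is pg_b
```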