"git@developer.sourcefind.cn:OpenDAS/autoawq_kernels.git" did not exist on "e90b731a667aa1efae0edc64ac120a07a844ee2c"
Unverified Commit bde4bac5 authored by Min Xu, committed by GitHub

[fix] add and use get_process_group_cached (#678)

* [fix] add and use get_process_group_cached

- This commit makes FSDP avoid creating too many process groups by default
- Extra process groups are bad for GPU memory usage and init time

* add changelog

* lint

* note on speed

* add better assert output

The test seems to be flaky:
https://app.circleci.com/pipelines/github/facebookresearch/fairscale/2957/workflows/383c9f9f-f1a5-461c-8c41-e2e28ece037b/jobs/26783/steps



* update test reference memory values

- With cached process groups, the memory reported by pytorch is reduced as well
(due to the bucket buffer memory used for gradient reduction)
- The effect on memory is actually larger for the memory reported by nvidia-smi,
which is neither reported by pytorch nor checked by this test (see the sketch below).
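
A minimal way to see the gap between allocator-tracked memory and SMI-visible memory (illustration only, not part of this commit; assumes the optional `pynvml` package and a CUDA device):

```python
# Illustration only, not part of this commit. NCCL buffers for extra process
# groups bypass the PyTorch caching allocator, so they appear in nvidia-smi's
# number but not in torch's. Assumes the optional `pynvml` package is installed.
import pynvml
import torch

def report_gpu_memory(device_index: int = 0) -> None:
    torch_reserved = torch.cuda.memory_reserved(device_index)   # caching-allocator view
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
    smi_used = pynvml.nvmlDeviceGetMemoryInfo(handle).used      # what nvidia-smi reports
    pynvml.nvmlShutdown()
    untracked = smi_used - torch_reserved                       # NCCL buffers, CUDA context, ...
    print(f"torch reserved: {torch_reserved >> 20} MiB, "
          f"smi used: {smi_used >> 20} MiB, "
          f"untracked: {untracked >> 20} MiB")
```

Shrinking the NCCL buffers (e.g. launching with NCCL_BUFFSIZE=102400 in the environment, as the new docstring below suggests) and re-measuring the untracked portion is one way to confirm whether NCCL memory is the culprit.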

* Update fairscale/nn/data_parallel/fully_sharded_data_parallel.py

* Update fairscale/nn/data_parallel/fully_sharded_data_parallel.py

* Update CHANGELOG.md

* Update fairscale/utils/parallel.py

* Update fairscale/utils/parallel.py

* Update fairscale/utils/parallel.py

* Update fairscale/utils/parallel.py

* improved changelog

* better handling of underscores in the md file
Co-authored-by: Min Xu <min.xu@acm.org>
parent 72c6bab2
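
For context, a minimal sketch of the intended effect of this change (not part of the diff below; assumes torch.distributed is already initialized, e.g. via dist.init_process_group):

```python
# Sketch only, not part of this commit. With the change, FSDP wrappers that are
# constructed without an explicit process_group fall back to the same cached
# default group, instead of each calling dist.new_group().
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.utils.parallel import get_process_group_cached

def build_wrapped_model() -> FSDP:
    inner = FSDP(nn.Linear(16, 16))                        # no process_group passed
    outer = FSDP(nn.Sequential(inner, nn.Linear(16, 16)))  # no process_group passed
    # Both wrappers now share one group (and hence one reduction bucket buffer).
    assert inner.process_group is outer.process_group
    assert outer.process_group is get_process_group_cached()
    return outer
```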
......@@ -6,20 +6,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## NEXT - TBD
### Fixed
- FSDP: fix forward pass not overlapping compute and all-gather
- FSDP: improved frozen weight support
- FSDP: workaround AMP autocast cache issue with clear\_autocast\_cache flag
- setup.py: hide CUDA extensions behind BUILD_CUDA_EXTENSIONS envvar
- SDP: re-expose the module property ([#647](https://github.com/facebookresearch/fairscale/pull/647))
- Cleanup - rename and move the checkpoint_activations wrapper ([654]https://github.com/facebookresearch/fairscale/pull/654)
- FSDP: fix extra process groups being created by default. Old behavior can cause excessive GPU memory usage. [#678]
- FSDP: fix forward pass not overlapping compute and allgather [#671]
- FSDP: improved frozen weight support [#657]
- FSDP: workaround AMP autocast cache issue with `clear_autocast_cache` flag [#650]
- MoE: several fixes [#666] [#667] [#668]
- setup.py: hide CUDA extensions behind `BUILD_CUDA_EXTENSIONS` envvar [#634]
- SDP: re-expose the module property [#647]
- Cleanup - rename and move the `checkpoint_activations` wrapper [#654]
### Added
- FSDP: added `force\_input\_to\_fp32` flag for SyncBatchNorm
- FSDP: better memory usage for reduce bucket ([#633](https://github.com/facebookresearch/fairscale/pull/633))
- FSDP: added `force_input_to_fp32` flag for SyncBatchNorm [#659]
- FSDP: better memory usage for reduce bucket [#633]
- Experimental SyncBatchNorm [#662]
## [0.3.6] - 2021-04-26
### Added
- FSDP: Consolidate cpu_adam optimizer state dict ([#607](https://github.com/facebookresearch/fairscale/pull/607))
- FSDP: Consolidate cpu\_adam optimizer state dict ([#607](https://github.com/facebookresearch/fairscale/pull/607))
### Fixed
- FSDP: handle model with multiple forward pass and checkpoint ([#621](https://github.com/facebookresearch/fairscale/pull/621))
......@@ -52,7 +55,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [0.3.3] - 2021-04-1
### Added
- FSDP: changed auto\_wrap\_bn utility function so that single FSDP group is optional ([#556](https://github.com/facebookresearch/fairscale/pull/556))
- FSDP: changed `auto_wrap_bn` utility function so that single FSDP group is optional ([#556](https://github.com/facebookresearch/fairscale/pull/556))
- FSDP: optimizer state load/save ([#537](https://github.com/facebookresearch/fairscale/pull/537))
- FSDP: fix weight init when using apply() ([#543](https://github.com/facebookresearch/fairscale/pull/543))
- Multiprocess Pipe: retired old implementation
......@@ -65,7 +68,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Experimental: Add spectrain support ([#372](https://github.com/facebookresearch/fairscale/issues/372))
- FSDP: enabled pytorch SyncBN (no asserting) ([#527](https://github.com/facebookresearch/fairscale/issues/527))
- FSDP: added auto\_wrap\_bn utility function ([#531](https://github.com/facebookresearch/fairscale/pull/531))
- FSDP: added `auto_wrap_bn` utility function ([#531](https://github.com/facebookresearch/fairscale/pull/531))
### Fixed
- OSS: fix a compatibily problem with lightning wrt optimizer state dict ([#510](https://github.com/facebookresearch/fairscale/issues/510))
......@@ -74,7 +77,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [0.3.1] - 2021-03-09
### Added
- FSDP docs ([#455](https://github.com/facebookresearch/fairscale/issues/455))
- enable\_wrap and auto\_wrap APIs ([#446](https://github.com/facebookresearch/fairscale/issues/446))
- `enable_wrap` and `auto_wrap` APIs ([#446](https://github.com/facebookresearch/fairscale/issues/446))
- Added experimental.nn.OffloadModel API for training large models on a single GPU.([#432](https://github.com/facebookresearch/fairscale/issues/432))
### Fixed
......@@ -91,7 +94,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- Catch corner case when the model is too small with respect to the world size, and shards are empty ([#406](https://github.com/facebookresearch/fairscale/pull/406))
- Memory leak in checkpoint\_wrapper ([#412](https://github.com/facebookresearch/fairscale/pull/412))
- Memory leak in `checkpoint_wrapper` ([#412](https://github.com/facebookresearch/fairscale/pull/412))
## [0.1.7] - 2021-02-19
### Fixed
......@@ -141,9 +144,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- AdaScale:
. Added gradient accumulation feature (#202)
. Added support of torch.lr_scheduler (#229)
. Added support for add_param_groups (#266)
. Added support for scale != world_size (#266)
. Added support of `torch.lr_scheduler` (#229)
. Added support for `add_param_groups` (#266)
. Added support for `scale != world_size` (#266)
### Fixed
- AdaScale: smoothing factor value fixed when using gradient accumulation (#235)
......@@ -162,7 +165,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- cpu support for Pipe (#188)
- ShardedOptim: Distributed Grad Scaler (for torch AMP) (#182)
- OSS-aware clip grads, bridge sharded states (#167)
- oss: add rank_local_state_dict staticmethod (#174)
- oss: add `rank_local_state_dict` staticmethod (#174)
- support for PyTorch 1.7.0 (#171)
- Add implementation of AdaScale (#139)
......@@ -179,7 +182,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [0.0.2] - 2020-08-28
### Added
- add ddp that works with oss with reduce() not all_reduce() (#19)
- add ddp that works with oss with `reduce()` not `all_reduce()` (#19)
- support for PyTorch v1.6
- add mixed precision Adam (#40)
- Adam optimizer state scaling (#44)
......@@ -189,7 +192,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- OSS restore state to proper device (#46)
- optim/oss: support optimizers with additional step kwargs (#53)
- optim/oss: fix state cast (#56)
- fix eval for oss_ddp (#55)
- fix eval for `oss_ddp` (#55)
- optim/oss: work correctly with LRScheduler (#58)
## [0.0.1] - 2020-07-31
......
......@@ -24,7 +24,12 @@ import torch.nn.functional as F
from fairscale.nn.misc import FlattenParamsWrapper
from fairscale.nn.wrap import auto_wrap, default_auto_wrap_policy, enable_wrap
from fairscale.utils.containers import apply_to_tensors
from fairscale.utils.parallel import chunk_and_pad, enable_pytorch_sync_bn, validate_process_group
from fairscale.utils.parallel import (
chunk_and_pad,
enable_pytorch_sync_bn,
get_process_group_cached,
validate_process_group,
)
from fairscale.utils.params import broadcast_object, calc_grad_norm, recursive_copy_to_device
from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
from fairscale.utils.state_dict import replace_by_prefix_
......@@ -233,7 +238,7 @@ class FullyShardedDataParallel(nn.Module):
):
init_start = time.time()
super().__init__()
self.process_group = process_group or dist.new_group()
self.process_group = process_group or get_process_group_cached()
self.rank = self.process_group.rank()
self.world_size = self.process_group.size()
self.reshard_after_forward = reshard_after_forward
......@@ -1816,13 +1821,11 @@ def auto_wrap_bn(
module, tuple(default_auto_wrap_policy.EXCLUDE_WRAP_MODULES) # type: ignore
)
pg = None
pg = process_group
if single_rank_pg:
# No sharding with this single member group.
my_rank = dist.get_rank()
pg = dist.new_group(ranks=[my_rank])
else:
pg = process_group
pg = get_process_group_cached(ranks=[my_rank])
if fsdp_config is None:
fsdp_config = {
......
......@@ -5,7 +5,7 @@
"""Useful functions for parallel training."""
from typing import List
from typing import List, Optional
import torch
import torch.distributed as dist
......@@ -58,18 +58,44 @@ def enable_pytorch_sync_bn(module: torch.nn.Module) -> None:
layer._specify_ddp_gpu_num(1) # type: ignore
def get_global_group() -> None:
def get_process_group_cached(ranks: Optional[List[int]] = None) -> ProcessGroup:
"""
Singleton PyTorch distributed group.
Inspired by https://github.com/pytorch/fairseq
Singleton PyTorch distributed group cache. Inspired by the code from fairseq.
For FSDP, it is important to use a global group, otherwise, inner FSDP instances
will not share the gradient reduction bucket buffer with the root instance, end up using
more GPU memory.
For FSDP, it is important to use the same group between outer and inner FSDP instances,
otherwise, inner FSDP instances will not share the gradient reduction bucket buffer with
the root instance. This will result in increased GPU memory utilization.
Each separate process group also uses separate NCCL library instances, which will have
a significant effect on GPU memory use if too many process groups are created and used.
Setting NCCL_BUFFSIZE=102400 env variable is a useful technique to check if the NCCL
memory is causing GPU OOM. Note, the NCCL buffers are not allocated
through the PyTorch caching allocator, therefore, you may see GPU OOM even when
torch.cuda.reserved_memory() is still way below the total amount of GPU memory.
Extra process groups can also reduce training speed (observed on VISSL models).
Args:
ranks (Optional[List[int]]):
Ranks requested in the target group. None for all ranks.
Default: None
Returns:
(ProcessGroup):
Return the requested process group. Throws RuntimeError if torch.distributed module is not yet initialized.
"""
if dist.is_initialized():
if not hasattr(get_global_group, "_global_group"):
get_global_group._global_group = dist.new_group() # type: ignore
return get_global_group._global_group # type: ignore
else:
return None
if not dist.is_initialized():
raise RuntimeError("torch.distributed is not yet initialized but process group is requested.")
if not hasattr(get_process_group_cached, "_global_group_cache"):
get_process_group_cached._global_group_cache = {} # type: ignore
cache = get_process_group_cached._global_group_cache # type: ignore
if ranks is None:
ranks = list(range(dist.get_world_size()))
ranks_set = frozenset(ranks) # take care of ordering and duplicates in the ranks list.
if ranks_set not in cache:
cache[ranks_set] = dist.new_group(list(ranks_set))
return cache[ranks_set]
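
For illustration (not part of the diff): the cache above is keyed by frozenset(ranks), so repeated requests for the same set of ranks return the same group regardless of ordering or duplicates. This assumes torch.distributed is already initialized (say with 8 ranks) and that every rank in the default group makes the same calls, since dist.new_group is collective:

```python
pg_all   = get_process_group_cached()                  # ranks=None -> all ranks
pg_sub   = get_process_group_cached([0, 1, 2, 3])
pg_again = get_process_group_cached([3, 1, 0, 2, 2])   # same set, different order/dups

assert pg_all is get_process_group_cached()   # second call hits the cache
assert pg_sub is pg_again                     # cache key is frozenset(ranks)
```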
......@@ -21,7 +21,7 @@ import torch.optim as optim
from fairscale.nn import checkpoint_wrapper
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import auto_wrap_bn
from fairscale.utils.parallel import get_global_group
from fairscale.utils.parallel import get_process_group_cached
from fairscale.utils.testing import (
dist_init,
dump_all_tensors,
......@@ -33,7 +33,7 @@ from fairscale.utils.testing import (
def to_fsdp(module, fsdp_config):
return FSDP(module, process_group=get_global_group(), **fsdp_config)
return FSDP(module, process_group=get_process_group_cached(), **fsdp_config)
def get_cur_mem(rank, result, prefix):
......@@ -199,78 +199,78 @@ def test_fsdp_memory(fsdp, ckpt):
"iter 0: start": 3,
"iter 0: after fwd": 340,
"iter 0: after loss": 340,
"iter 0: after bwd": 66,
"iter 0: after step": 68,
"iter 0: after bwd": 16,
"iter 0: after step": 18,
"iter 0: done": 5,
"iter 1: start": 5,
"iter 1: after fwd": 342,
"iter 1: after loss": 342,
"iter 1: after bwd": 68,
"iter 1: after step": 68,
"iter 1: after bwd": 18,
"iter 1: after step": 18,
"iter 1: done": 5,
"iter 2: start": 5,
"iter 2: after fwd": 342,
"iter 2: after loss": 342,
"iter 2: after bwd": 68,
"iter 2: after step": 68,
"iter 2: after bwd": 18,
"iter 2: after step": 18,
"iter 2: done": 5,
"iter 3: start": 5,
"iter 3: after fwd": 342,
"iter 3: after loss": 342,
"iter 3: after bwd": 68,
"iter 3: after step": 68,
"iter 3: after bwd": 18,
"iter 3: after step": 18,
"iter 3: done": 5,
},
("fsdp_amp_default", "no_ckpt"): {
"iter 0: start": 28,
"iter 0: after fwd": 630,
"iter 0: after loss": 630,
"iter 0: after bwd": 104,
"iter 0: after step": 131,
"iter 0: after bwd": 67,
"iter 0: after step": 93,
"iter 0: done": 54,
"iter 1: start": 54,
"iter 1: after fwd": 657,
"iter 1: after loss": 657,
"iter 1: after bwd": 131,
"iter 1: after step": 131,
"iter 1: after bwd": 93,
"iter 1: after step": 93,
"iter 1: done": 54,
"iter 2: start": 54,
"iter 2: after fwd": 657,
"iter 2: after loss": 657,
"iter 2: after bwd": 131,
"iter 2: after step": 131,
"iter 2: after bwd": 93,
"iter 2: after step": 93,
"iter 2: done": 54,
"iter 3: start": 54,
"iter 3: after fwd": 657,
"iter 3: after loss": 657,
"iter 3: after bwd": 131,
"iter 3: after step": 131,
"iter 3: after bwd": 93,
"iter 3: after step": 93,
"iter 3: done": 54,
},
("fsdp_amp_compute_dtype32", "no_ckpt"): {
"iter 0: start": 28,
"iter 0: after fwd": 657,
"iter 0: after loss": 657,
"iter 0: after bwd": 117,
"iter 0: after step": 143,
"iter 0: after bwd": 67,
"iter 0: after step": 93,
"iter 0: done": 54,
"iter 1: start": 54,
"iter 1: after fwd": 684,
"iter 1: after loss": 684,
"iter 1: after bwd": 143,
"iter 1: after step": 143,
"iter 1: after bwd": 93,
"iter 1: after step": 93,
"iter 1: done": 54,
"iter 2: start": 54,
"iter 2: after fwd": 684,
"iter 2: after loss": 684,
"iter 2: after bwd": 143,
"iter 2: after step": 143,
"iter 2: after bwd": 93,
"iter 2: after step": 93,
"iter 2: done": 54,
"iter 3: start": 54,
"iter 3: after fwd": 684,
"iter 3: after loss": 684,
"iter 3: after bwd": 143,
"iter 3: after step": 143,
"iter 3: after bwd": 93,
"iter 3: after step": 93,
"iter 3: done": 54,
},
("ddp", "ckpt"): {
......@@ -303,78 +303,78 @@ def test_fsdp_memory(fsdp, ckpt):
"iter 0: start": 3,
"iter 0: after fwd": 51,
"iter 0: after loss": 51,
"iter 0: after bwd": 66,
"iter 0: after step": 68,
"iter 0: after bwd": 16,
"iter 0: after step": 18,
"iter 0: done": 5,
"iter 1: start": 5,
"iter 1: after fwd": 53,
"iter 1: after loss": 53,
"iter 1: after bwd": 68,
"iter 1: after step": 68,
"iter 1: after bwd": 18,
"iter 1: after step": 18,
"iter 1: done": 5,
"iter 2: start": 5,
"iter 2: after fwd": 53,
"iter 2: after loss": 53,
"iter 2: after bwd": 68,
"iter 2: after step": 68,
"iter 2: after bwd": 18,
"iter 2: after step": 18,
"iter 2: done": 5,
"iter 3: start": 5,
"iter 3: after fwd": 53,
"iter 3: after loss": 53,
"iter 3: after bwd": 68,
"iter 3: after step": 68,
"iter 3: after bwd": 18,
"iter 3: after step": 18,
"iter 3: done": 5,
},
("fsdp_amp_default", "ckpt"): {
"iter 0: start": 28,
"iter 0: after fwd": 52,
"iter 0: after loss": 52,
"iter 0: after bwd": 104,
"iter 0: after step": 131,
"iter 0: after bwd": 67,
"iter 0: after step": 93,
"iter 0: done": 54,
"iter 1: start": 54,
"iter 1: after fwd": 79,
"iter 1: after loss": 79,
"iter 1: after bwd": 131,
"iter 1: after step": 131,
"iter 1: after bwd": 93,
"iter 1: after step": 93,
"iter 1: done": 54,
"iter 2: start": 54,
"iter 2: after fwd": 79,
"iter 2: after loss": 79,
"iter 2: after bwd": 131,
"iter 2: after step": 131,
"iter 2: after bwd": 93,
"iter 2: after step": 93,
"iter 2: done": 54,
"iter 3: start": 54,
"iter 3: after fwd": 79,
"iter 3: after loss": 79,
"iter 3: after bwd": 131,
"iter 3: after step": 131,
"iter 3: after bwd": 93,
"iter 3: after step": 93,
"iter 3: done": 54,
},
("fsdp_amp_compute_dtype32", "ckpt"): {
"iter 0: start": 28,
"iter 0: after fwd": 52,
"iter 0: after loss": 52,
"iter 0: after bwd": 117,
"iter 0: after step": 143,
"iter 0: after bwd": 67,
"iter 0: after step": 93,
"iter 0: done": 54,
"iter 1: start": 54,
"iter 1: after fwd": 79,
"iter 1: after loss": 79,
"iter 1: after bwd": 143,
"iter 1: after step": 143,
"iter 1: after bwd": 93,
"iter 1: after step": 93,
"iter 1: done": 54,
"iter 2: start": 54,
"iter 2: after fwd": 79,
"iter 2: after loss": 79,
"iter 2: after bwd": 143,
"iter 2: after step": 143,
"iter 2: after bwd": 93,
"iter 2: after step": 93,
"iter 2: done": 54,
"iter 3: start": 54,
"iter 3: after fwd": 79,
"iter 3: after loss": 79,
"iter 3: after bwd": 143,
"iter 3: after step": 143,
"iter 3: after bwd": 93,
"iter 3: after step": 93,
"iter 3: done": 54,
},
}[(fsdp, ckpt)]
......
......@@ -62,7 +62,7 @@ def expert_params(device):
expert = torch.nn.Linear(model_dim, model_dim)
moe = MOELayer(gate, expert).to(device)
for p in expert.parameters():
assert p.expert is True
assert p.expert is True, str(p.expert)
@pg_test()
......@@ -77,9 +77,9 @@ def forward(device):
expert.weight = torch.nn.Parameter(torch.eye(model_dim))
moe = MOELayer(gate, expert).to(device)
output = moe(input)
assert output.shape == input.shape
assert output.shape == input.shape, f"{output.shape} != {input.shape}"
# Re-assembled output should match input due to identity expert.
assert torch.allclose(input, output)
torch.testing.assert_allclose(input, output)
@pg_test()
......@@ -99,11 +99,13 @@ def forward_multi(device):
experts += [expert]
moe = MOELayer(gate, torch.nn.ModuleList(experts)).to(device)
output = moe(input)
assert output.shape == input.shape
assert output.shape == input.shape, f"{output.shape} != {input.shape}"
# 90% of the input should have gone to an expert
assert len(output.nonzero(as_tuple=False)) / output.numel() > 0.90
assert (
len(output.nonzero(as_tuple=False)) / output.numel() > 0.90
), f"{len(output.nonzero(as_tuple=False))} / {output.numel()}"
# Except for zeros, re-assembled output should match input due to identity expert.
assert torch.allclose(input, torch.where(output > 0, output, input))
torch.testing.assert_allclose(input, torch.where(output > 0, output, input))
# Test Gate which round-robin routes tokens to experts
......@@ -115,7 +117,7 @@ class RoundRobinGate(torch.nn.Module):
def forward(self, input):
s = input.shape[0]
assert s % self.num_experts == 0
assert s % self.num_experts == 0, f"{s} % {self.num_experts} != 0"
capacity = 2 * s // self.num_experts
output = torch.zeros(s, self.num_experts, capacity, dtype=input.dtype, device=input.device)
for i in range(s):
......@@ -136,12 +138,12 @@ def forward_routing(device):
expert.weight = torch.nn.Parameter(torch.eye(model_dim) * scale)
moe = MOELayer(gate, expert).to(device)
output = moe(input)
assert output.shape == input.shape
assert output.shape == input.shape, f"{output.shape} != {input.shape}"
# Verify that each token was sent to the correct expert by checking its scale.
t = input.shape[1]
for i in range(t):
expert = i % num_experts
assert torch.allclose(input[:, i] * (expert + 1), output[:, i])
torch.testing.assert_allclose(input[:, i] * (expert + 1), output[:, i])
@pg_test()
......@@ -161,12 +163,12 @@ def forward_routing_multi(device):
experts += [expert]
moe = MOELayer(gate, torch.nn.ModuleList(experts)).to(device)
output = moe(input)
assert output.shape == input.shape
assert output.shape == input.shape, f"{output.shape} != {input.shape}"
# Verify that each token was sent to the correct expert by checking its scale.
t = input.shape[1]
for i in range(t):
expert = i % num_experts
assert torch.allclose(input[:, i] * (expert + 1), output[:, i])
torch.testing.assert_allclose(input[:, i] * (expert + 1), output[:, i])
@pg_test()
......@@ -182,7 +184,7 @@ def backward(device):
expert.weight = torch.nn.Parameter(torch.eye(model_dim))
moe = MOELayer(gate, expert).to(device)
output = moe(input)
assert output.shape == input.shape
assert output.shape == input.shape, f"{output.shape} != {input.shape}"
output = loss(output, input)
output.backward()
assert torch.allclose(expert.weight.grad, torch.zeros_like(expert.weight))
torch.testing.assert_allclose(expert.weight.grad, torch.zeros_like(expert.weight))