Unverified commit bde4bac5, authored by Min Xu, committed by GitHub

[fix] add and use get_process_group_cached (#678)

* [fix] add and use get_process_group_cached

- This commit makes FSDP avoid creating too many process groups by default
- Extra process groups are bad for GPU memory and init time

* add changelog

* lint

* note on speed

* add better assert output

test seems to be flaky:
https://app.circleci.com/pipelines/github/facebookresearch/fairscale/2957/workflows/383c9f9f-f1a5-461c-8c41-e2e28ece037b/jobs/26783/steps



* update test reference memory values

- With cached process groups, the memory reported by pytorch is reduced as well (due to the bucket buffer memory for the reduction buffer)
- The effect on memory is actually larger on the SMI memory (as shown by nvidia-smi), which is not reported by pytorch and not checked by this test.

* Update fairscale/nn/data_parallel/fully_sharded_data_parallel.py

* Update fairscale/nn/data_parallel/fully_sharded_data_parallel.py

* Update CHANGELOG.md

* Update fairscale/utils/parallel.py

* Update fairscale/utils/parallel.py

* Update fairscale/utils/parallel.py

* Update fairscale/utils/parallel.py

* improved changelog

* better handling of underscores in the md file
Co-authored-by: Min Xu <min.xu@acm.org>
parent 72c6bab2
@@ -6,20 +6,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## NEXT - TBD
 ### Fixed
-- FSDP: fix forward pass not overlapping compute and all-gather
-- FSDP: improved frozen weight support
-- FSDP: workaround AMP autocast cache issue with clear\_autocast\_cache flag
-- setup.py: hide CUDA extensions behind BUILD_CUDA_EXTENSIONS envvar
-- SDP: re-expose the module property ([#647](https://github.com/facebookresearch/fairscale/pull/647))
-- Cleanup - rename and move the checkpoint_activations wrapper ([654]https://github.com/facebookresearch/fairscale/pull/654)
+- FSDP: fix extra process groups being created by default. Old behavior can cause excessive GPU memory usage. [#678]
+- FSDP: fix forward pass not overlapping compute and allgather [#671]
+- FSDP: improved frozen weight support [#657]
+- FSDP: workaround AMP autocast cache issue with `clear_autocast_cache` flag [#650]
+- MoE: several fixes [#666] [#667] [#668]
+- setup.py: hide CUDA extensions behind `BUILD_CUDA_EXTENSIONS` envvar [#634]
+- SDP: re-expose the module property [#647]
+- Cleanup - rename and move the `checkpoint_activations` wrapper [#654]
 ### Added
-- FSDP: added `force\_input\_to\_fp32` flag for SyncBatchNorm
-- FSDP: better memory usage for reduce bucket ([#633](https://github.com/facebookresearch/fairscale/pull/633))
+- FSDP: added `force_input_to_fp32` flag for SyncBatchNorm [#659]
+- FSDP: better memory usage for reduce bucket [#633]
+- Experimental SyncBatchNorm [#662]
 ## [0.3.6] - 2021-04-26
 ### Added
-- FSDP: Consolidate cpu_adam optimizer state dict ([#607](https://github.com/facebookresearch/fairscale/pull/607))
+- FSDP: Consolidate cpu\_adam optimizer state dict ([#607](https://github.com/facebookresearch/fairscale/pull/607))
 ### Fixed
 - FSDP: handle model with multiple forward pass and checkpoint ([#621](https://github.com/facebookresearch/fairscale/pull/621))
@@ -52,7 +55,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## [0.3.3] - 2021-04-1
 ### Added
-- FSDP: changed auto\_wrap\_bn utility function so that single FSDP group is optional ([#556](https://github.com/facebookresearch/fairscale/pull/556))
+- FSDP: changed `auto_wrap_bn` utility function so that single FSDP group is optional ([#556](https://github.com/facebookresearch/fairscale/pull/556))
 - FSDP: optimizer state load/save ([#537](https://github.com/facebookresearch/fairscale/pull/537))
 - FSDP: fix weight init when using apply() ([#543](https://github.com/facebookresearch/fairscale/pull/543))
 - Multiprocess Pipe: retired old implementation
@@ -65,7 +68,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 - Experimental: Add spectrain support ([#372](https://github.com/facebookresearch/fairscale/issues/372))
 - FSDP: enabled pytorch SyncBN (no asserting) ([#527](https://github.com/facebookresearch/fairscale/issues/527))
-- FSDP: added auto\_wrap\_bn utility function ([#531](https://github.com/facebookresearch/fairscale/pull/531))
+- FSDP: added `auto_wrap_bn` utility function ([#531](https://github.com/facebookresearch/fairscale/pull/531))
 ### Fixed
 - OSS: fix a compatibily problem with lightning wrt optimizer state dict ([#510](https://github.com/facebookresearch/fairscale/issues/510))
@@ -74,7 +77,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## [0.3.1] - 2021-03-09
 ### Added
 - FSDP docs ([#455](https://github.com/facebookresearch/fairscale/issues/455))
-- enable\_wrap and auto\_wrap APIs ([#446](https://github.com/facebookresearch/fairscale/issues/446))
+- `enable_wrap` and `auto_wrap` APIs ([#446](https://github.com/facebookresearch/fairscale/issues/446))
 - Added experimental.nn.OffloadModel API for training large models on a single GPU.([#432](https://github.com/facebookresearch/fairscale/issues/432))
 ### Fixed
@@ -91,7 +94,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Fixed
 - Catch corner case when the model is too small with respect to the world size, and shards are empty ([#406](https://github.com/facebookresearch/fairscale/pull/406))
-- Memory leak in checkpoint\_wrapper ([#412](https://github.com/facebookresearch/fairscale/pull/412))
+- Memory leak in `checkpoint_wrapper` ([#412](https://github.com/facebookresearch/fairscale/pull/412))
 ## [0.1.7] - 2021-02-19
 ### Fixed
@@ -141,9 +144,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 - AdaScale:
 .  Added gradient accumulation feature (#202)
-.  Added support of torch.lr_scheduler (#229)
-.  Added support for add_param_groups (#266)
-.  Added support for scale != world_size (#266)
+.  Added support of `torch.lr_scheduler` (#229)
+.  Added support for `add_param_groups` (#266)
+.  Added support for `scale != world_size` (#266)
 ### Fixed
 - AdaScale: smoothing factor value fixed when using gradient accumulation (#235)
@@ -162,7 +165,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - cpu support for Pipe (#188)
 - ShardedOptim: Distributed Grad Scaler (for torch AMP) (#182)
 - OSS-aware clip grads, bridge sharded states (#167)
-- oss: add rank_local_state_dict staticmethod (#174)
+- oss: add `rank_local_state_dict` staticmethod (#174)
 - support for PyTorch 1.7.0 (#171)
 - Add implementation of AdaScale (#139)
@@ -179,7 +182,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## [0.0.2] - 2020-08-28
 ### Added
-- add ddp that works with oss with reduce() not all_reduce() (#19)
+- add ddp that works with oss with `reduce()` not `all_reduce()` (#19)
 - support for PyTorch v1.6
 - add mixed precision Adam (#40)
 - Adam optimizer state scaling (#44)
@@ -189,7 +192,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - OSS restore state to proper device (#46)
 - optim/oss: support optimizers with additional step kwargs (#53)
 - optim/oss: fix state cast (#56)
-- fix eval for oss_ddp (#55)
+- fix eval for `oss_ddp` (#55)
 - optim/oss: work correctly with LRScheduler (#58)
 ## [0.0.1] - 2020-07-31
...
@@ -24,7 +24,12 @@ import torch.nn.functional as F
 from fairscale.nn.misc import FlattenParamsWrapper
 from fairscale.nn.wrap import auto_wrap, default_auto_wrap_policy, enable_wrap
 from fairscale.utils.containers import apply_to_tensors
-from fairscale.utils.parallel import chunk_and_pad, enable_pytorch_sync_bn, validate_process_group
+from fairscale.utils.parallel import (
+    chunk_and_pad,
+    enable_pytorch_sync_bn,
+    get_process_group_cached,
+    validate_process_group,
+)
 from fairscale.utils.params import broadcast_object, calc_grad_norm, recursive_copy_to_device
 from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
 from fairscale.utils.state_dict import replace_by_prefix_
@@ -233,7 +238,7 @@ class FullyShardedDataParallel(nn.Module):
     ):
         init_start = time.time()
         super().__init__()
-        self.process_group = process_group or dist.new_group()
+        self.process_group = process_group or get_process_group_cached()
         self.rank = self.process_group.rank()
         self.world_size = self.process_group.size()
         self.reshard_after_forward = reshard_after_forward
@@ -1816,13 +1821,11 @@ def auto_wrap_bn(
         module, tuple(default_auto_wrap_policy.EXCLUDE_WRAP_MODULES)  # type: ignore
     )
-    pg = None
+    pg = process_group
     if single_rank_pg:
         # No sharding with this single member group.
         my_rank = dist.get_rank()
-        pg = dist.new_group(ranks=[my_rank])
-    else:
-        pg = process_group
+        pg = get_process_group_cached(ranks=[my_rank])
     if fsdp_config is None:
         fsdp_config = {
...
@@ -5,7 +5,7 @@
 """Useful functions for parallel training."""
-from typing import List
+from typing import List, Optional
 import torch
 import torch.distributed as dist
@@ -58,18 +58,44 @@ def enable_pytorch_sync_bn(module: torch.nn.Module) -> None:
             layer._specify_ddp_gpu_num(1)  # type: ignore
-def get_global_group() -> None:
+def get_process_group_cached(ranks: Optional[List[int]] = None) -> ProcessGroup:
     """
-    Singleton PyTorch distributed group.
-    Inspired by https://github.com/pytorch/fairseq
+    Singleton PyTorch distributed group cache. Inspired by the code from fairseq.
-    For FSDP, it is important to use a global group, otherwise, inner FSDP instances
-    will not share the gradient reduction bucket buffer with the root instance, end up using
-    more GPU memory.
+    For FSDP, it is important to use the same group between outer and inner FSDP instances,
+    otherwise, inner FSDP instances will not share the gradient reduction bucket buffer with
+    the root instance. This will result in increased GPU memory utilization.
+
+    Each separate process group also uses separate NCCL library instances, which will have
+    a significant effect on GPU memory use if too many process groups are created and used.
+    Setting the NCCL_BUFFSIZE=102400 env variable is a useful technique to check if the NCCL
+    memory is causing GPU OOM. Note, the NCCL buffers are not allocated
+    through the PyTorch caching allocator, therefore, you may see GPU OOM even when
+    torch.cuda.reserved_memory() is still way below the total amount of GPU memory.
+
+    Extra process groups can also reduce training speed (observed on VISSL models).
+
+    Args:
+        ranks (Optional[List[int]]):
+            Ranks requested in the target group. None for all ranks.
+            Default: None
+
+    Returns:
+        (ProcessGroup):
+            Return the requested process group. Throws RuntimeError if torch.distributed module is not yet initialized.
     """
-    if dist.is_initialized():
-        if not hasattr(get_global_group, "_global_group"):
-            get_global_group._global_group = dist.new_group()  # type: ignore
-        return get_global_group._global_group  # type: ignore
-    else:
-        return None
+    if not dist.is_initialized():
+        raise RuntimeError("torch.distributed is not yet initialized but process group is requested.")
+
+    if not hasattr(get_process_group_cached, "_global_group_cache"):
+        get_process_group_cached._global_group_cache = {}  # type: ignore
+    cache = get_process_group_cached._global_group_cache  # type: ignore
+    if ranks is None:
+        ranks = list(range(dist.get_world_size()))
+    ranks_set = frozenset(ranks)  # take care of ordering and duplicates in the ranks list.
+    if ranks_set not in cache:
+        cache[ranks_set] = dist.new_group(list(ranks_set))
+    return cache[ranks_set]
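The cache above hangs off the function object itself rather than living in a module-level global. A minimal standalone sketch of that idiom (names hypothetical; `object()` stands in for an expensive call such as creating a communicator):

```python
def expensive_singleton():
    """Create a value once and memoize it as an attribute of the function itself."""
    if not hasattr(expensive_singleton, "_cache"):
        # Stand-in for an expensive one-time construction.
        expensive_singleton._cache = object()
    return expensive_singleton._cache

# Every call after the first returns the very same object.
assert expensive_singleton() is expensive_singleton()
```

One property of this pattern is that the cache's lifetime is tied to the function definition, so reloading the module naturally resets it.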
@@ -21,7 +21,7 @@ import torch.optim as optim
 from fairscale.nn import checkpoint_wrapper
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import auto_wrap_bn
-from fairscale.utils.parallel import get_global_group
+from fairscale.utils.parallel import get_process_group_cached
 from fairscale.utils.testing import (
     dist_init,
     dump_all_tensors,
@@ -33,7 +33,7 @@ from fairscale.utils.testing import (
 def to_fsdp(module, fsdp_config):
-    return FSDP(module, process_group=get_global_group(), **fsdp_config)
+    return FSDP(module, process_group=get_process_group_cached(), **fsdp_config)
 def get_cur_mem(rank, result, prefix):
@@ -199,78 +199,78 @@ def test_fsdp_memory(fsdp, ckpt):
             "iter 0: start": 3,
             "iter 0: after fwd": 340,
             "iter 0: after loss": 340,
-            "iter 0: after bwd": 66,
-            "iter 0: after step": 68,
+            "iter 0: after bwd": 16,
+            "iter 0: after step": 18,
             "iter 0: done": 5,
             "iter 1: start": 5,
             "iter 1: after fwd": 342,
             "iter 1: after loss": 342,
-            "iter 1: after bwd": 68,
-            "iter 1: after step": 68,
+            "iter 1: after bwd": 18,
+            "iter 1: after step": 18,
             "iter 1: done": 5,
             "iter 2: start": 5,
             "iter 2: after fwd": 342,
             "iter 2: after loss": 342,
-            "iter 2: after bwd": 68,
-            "iter 2: after step": 68,
+            "iter 2: after bwd": 18,
+            "iter 2: after step": 18,
             "iter 2: done": 5,
             "iter 3: start": 5,
             "iter 3: after fwd": 342,
             "iter 3: after loss": 342,
-            "iter 3: after bwd": 68,
-            "iter 3: after step": 68,
+            "iter 3: after bwd": 18,
+            "iter 3: after step": 18,
             "iter 3: done": 5,
         },
         ("fsdp_amp_default", "no_ckpt"): {
             "iter 0: start": 28,
             "iter 0: after fwd": 630,
             "iter 0: after loss": 630,
-            "iter 0: after bwd": 104,
-            "iter 0: after step": 131,
+            "iter 0: after bwd": 67,
+            "iter 0: after step": 93,
             "iter 0: done": 54,
             "iter 1: start": 54,
             "iter 1: after fwd": 657,
             "iter 1: after loss": 657,
-            "iter 1: after bwd": 131,
-            "iter 1: after step": 131,
+            "iter 1: after bwd": 93,
+            "iter 1: after step": 93,
             "iter 1: done": 54,
             "iter 2: start": 54,
             "iter 2: after fwd": 657,
             "iter 2: after loss": 657,
-            "iter 2: after bwd": 131,
-            "iter 2: after step": 131,
+            "iter 2: after bwd": 93,
+            "iter 2: after step": 93,
             "iter 2: done": 54,
             "iter 3: start": 54,
             "iter 3: after fwd": 657,
             "iter 3: after loss": 657,
-            "iter 3: after bwd": 131,
-            "iter 3: after step": 131,
+            "iter 3: after bwd": 93,
+            "iter 3: after step": 93,
             "iter 3: done": 54,
         },
         ("fsdp_amp_compute_dtype32", "no_ckpt"): {
             "iter 0: start": 28,
             "iter 0: after fwd": 657,
             "iter 0: after loss": 657,
-            "iter 0: after bwd": 117,
-            "iter 0: after step": 143,
+            "iter 0: after bwd": 67,
+            "iter 0: after step": 93,
             "iter 0: done": 54,
             "iter 1: start": 54,
             "iter 1: after fwd": 684,
             "iter 1: after loss": 684,
-            "iter 1: after bwd": 143,
-            "iter 1: after step": 143,
+            "iter 1: after bwd": 93,
+            "iter 1: after step": 93,
             "iter 1: done": 54,
             "iter 2: start": 54,
             "iter 2: after fwd": 684,
             "iter 2: after loss": 684,
-            "iter 2: after bwd": 143,
-            "iter 2: after step": 143,
+            "iter 2: after bwd": 93,
+            "iter 2: after step": 93,
             "iter 2: done": 54,
             "iter 3: start": 54,
             "iter 3: after fwd": 684,
             "iter 3: after loss": 684,
-            "iter 3: after bwd": 143,
-            "iter 3: after step": 143,
+            "iter 3: after bwd": 93,
+            "iter 3: after step": 93,
             "iter 3: done": 54,
         },
         ("ddp", "ckpt"): {
@@ -303,78 +303,78 @@ def test_fsdp_memory(fsdp, ckpt):
             "iter 0: start": 3,
             "iter 0: after fwd": 51,
             "iter 0: after loss": 51,
-            "iter 0: after bwd": 66,
-            "iter 0: after step": 68,
+            "iter 0: after bwd": 16,
+            "iter 0: after step": 18,
             "iter 0: done": 5,
             "iter 1: start": 5,
             "iter 1: after fwd": 53,
             "iter 1: after loss": 53,
-            "iter 1: after bwd": 68,
-            "iter 1: after step": 68,
+            "iter 1: after bwd": 18,
+            "iter 1: after step": 18,
             "iter 1: done": 5,
             "iter 2: start": 5,
             "iter 2: after fwd": 53,
             "iter 2: after loss": 53,
-            "iter 2: after bwd": 68,
-            "iter 2: after step": 68,
+            "iter 2: after bwd": 18,
+            "iter 2: after step": 18,
             "iter 2: done": 5,
             "iter 3: start": 5,
             "iter 3: after fwd": 53,
             "iter 3: after loss": 53,
-            "iter 3: after bwd": 68,
-            "iter 3: after step": 68,
+            "iter 3: after bwd": 18,
+            "iter 3: after step": 18,
             "iter 3: done": 5,
         },
         ("fsdp_amp_default", "ckpt"): {
             "iter 0: start": 28,
             "iter 0: after fwd": 52,
             "iter 0: after loss": 52,
-            "iter 0: after bwd": 104,
-            "iter 0: after step": 131,
+            "iter 0: after bwd": 67,
+            "iter 0: after step": 93,
             "iter 0: done": 54,
             "iter 1: start": 54,
             "iter 1: after fwd": 79,
             "iter 1: after loss": 79,
-            "iter 1: after bwd": 131,
-            "iter 1: after step": 131,
+            "iter 1: after bwd": 93,
+            "iter 1: after step": 93,
             "iter 1: done": 54,
             "iter 2: start": 54,
             "iter 2: after fwd": 79,
             "iter 2: after loss": 79,
-            "iter 2: after bwd": 131,
-            "iter 2: after step": 131,
+            "iter 2: after bwd": 93,
+            "iter 2: after step": 93,
             "iter 2: done": 54,
             "iter 3: start": 54,
             "iter 3: after fwd": 79,
             "iter 3: after loss": 79,
-            "iter 3: after bwd": 131,
-            "iter 3: after step": 131,
+            "iter 3: after bwd": 93,
+            "iter 3: after step": 93,
             "iter 3: done": 54,
         },
         ("fsdp_amp_compute_dtype32", "ckpt"): {
             "iter 0: start": 28,
             "iter 0: after fwd": 52,
             "iter 0: after loss": 52,
-            "iter 0: after bwd": 117,
-            "iter 0: after step": 143,
+            "iter 0: after bwd": 67,
+            "iter 0: after step": 93,
             "iter 0: done": 54,
             "iter 1: start": 54,
             "iter 1: after fwd": 79,
             "iter 1: after loss": 79,
-            "iter 1: after bwd": 143,
-            "iter 1: after step": 143,
+            "iter 1: after bwd": 93,
+            "iter 1: after step": 93,
             "iter 1: done": 54,
             "iter 2: start": 54,
             "iter 2: after fwd": 79,
             "iter 2: after loss": 79,
-            "iter 2: after bwd": 143,
-            "iter 2: after step": 143,
+            "iter 2: after bwd": 93,
+            "iter 2: after step": 93,
             "iter 2: done": 54,
             "iter 3: start": 54,
             "iter 3: after fwd": 79,
             "iter 3: after loss": 79,
-            "iter 3: after bwd": 143,
-            "iter 3: after step": 143,
+            "iter 3: after bwd": 93,
+            "iter 3: after step": 93,
             "iter 3: done": 54,
         },
     }[(fsdp, ckpt)]
...
@@ -62,7 +62,7 @@ def expert_params(device):
     expert = torch.nn.Linear(model_dim, model_dim)
     moe = MOELayer(gate, expert).to(device)
     for p in expert.parameters():
-        assert p.expert is True
+        assert p.expert is True, str(p.expert)
 @pg_test()
@@ -77,9 +77,9 @@ def forward(device):
     expert.weight = torch.nn.Parameter(torch.eye(model_dim))
     moe = MOELayer(gate, expert).to(device)
     output = moe(input)
-    assert output.shape == input.shape
+    assert output.shape == input.shape, f"{output.shape} != {input.shape}"
     # Re-assembled output should match input due to identity expert.
-    assert torch.allclose(input, output)
+    torch.testing.assert_allclose(input, output)
 @pg_test()
@@ -99,11 +99,13 @@ def forward_multi(device):
         experts += [expert]
     moe = MOELayer(gate, torch.nn.ModuleList(experts)).to(device)
     output = moe(input)
-    assert output.shape == input.shape
+    assert output.shape == input.shape, f"{output.shape} != {input.shape}"
     # 90% of the input should have gone to an expert
-    assert len(output.nonzero(as_tuple=False)) / output.numel() > 0.90
+    assert (
+        len(output.nonzero(as_tuple=False)) / output.numel() > 0.90
+    ), f"{len(output.nonzero(as_tuple=False))} / {output.numel()}"
     # Except for zeros, re-assembled output should match input due to identity expert.
-    assert torch.allclose(input, torch.where(output > 0, output, input))
+    torch.testing.assert_allclose(input, torch.where(output > 0, output, input))
 # Test Gate which round-robin routes tokens to experts
@@ -115,7 +117,7 @@ class RoundRobinGate(torch.nn.Module):
     def forward(self, input):
         s = input.shape[0]
-        assert s % self.num_experts == 0
+        assert s % self.num_experts == 0, f"{s} % {self.num_experts} != 0"
         capacity = 2 * s // self.num_experts
         output = torch.zeros(s, self.num_experts, capacity, dtype=input.dtype, device=input.device)
         for i in range(s):
@@ -136,12 +138,12 @@ def forward_routing(device):
     expert.weight = torch.nn.Parameter(torch.eye(model_dim) * scale)
     moe = MOELayer(gate, expert).to(device)
     output = moe(input)
-    assert output.shape == input.shape
+    assert output.shape == input.shape, f"{output.shape} != {input.shape}"
     # Verify that each token was sent to the correct expert by checking its scale.
     t = input.shape[1]
     for i in range(t):
         expert = i % num_experts
-        assert torch.allclose(input[:, i] * (expert + 1), output[:, i])
+        torch.testing.assert_allclose(input[:, i] * (expert + 1), output[:, i])
 @pg_test()
@@ -161,12 +163,12 @@ def forward_routing_multi(device):
         experts += [expert]
     moe = MOELayer(gate, torch.nn.ModuleList(experts)).to(device)
     output = moe(input)
-    assert output.shape == input.shape
+    assert output.shape == input.shape, f"{output.shape} != {input.shape}"
     # Verify that each token was sent to the correct expert by checking its scale.
     t = input.shape[1]
     for i in range(t):
         expert = i % num_experts
-        assert torch.allclose(input[:, i] * (expert + 1), output[:, i])
+        torch.testing.assert_allclose(input[:, i] * (expert + 1), output[:, i])
 @pg_test()
@@ -182,7 +184,7 @@ def backward(device):
     expert.weight = torch.nn.Parameter(torch.eye(model_dim))
     moe = MOELayer(gate, expert).to(device)
     output = moe(input)
-    assert output.shape == input.shape
+    assert output.shape == input.shape, f"{output.shape} != {input.shape}"
     output = loss(output, input)
     output.backward()
-    assert torch.allclose(expert.weight.grad, torch.zeros_like(expert.weight))
+    torch.testing.assert_allclose(expert.weight.grad, torch.zeros_like(expert.weight))