# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
from time import time
import unittest

from parameterized import parameterized
import torch
from torch.optim import SGD, Adadelta, Adam  # type: ignore

from fairscale.nn import FullyShardedDataParallel
from fairscale.nn.data_parallel.fsdp_optim_utils import is_singleton_tensor
from fairscale.utils.params import recursive_copy_to_device
from fairscale.utils.testing import dist_init, objects_are_equal, spawn_for_all_world_sizes

from .test_fsdp import (
    DistributedTest,
    DummyProcessGroup,
    MixtureOfExperts,
    TransformerWithSharedParams,
    rename_test,
    spawn_and_init,
)


def all_tensors_numel_except_for_step(dct):
    """Compute the sum of numel over all tensors in a dict, except when the key is `step`."""
    ret = 0
    for k, v in dct.items():
        if k != "step" and torch.is_tensor(v):
            ret += v.numel()
    return ret


def assert_equal(a, b):
    assert a == b, f"{a} != {b}"


def spawn_and_init_multiple_groups(fn, args=None, **spawn_kwargs):
    if args is None:
        args = ()
    run_fn = functools.partial(init_and_run, fn, args)
    spawn_for_all_world_sizes(run_fn, **spawn_kwargs)


def _find_my_group_index(grouped_ranks):
    """Return the index corresponding to the MoE group of the current process."""
    my_rank = torch.distributed.get_rank()
    for i, group in enumerate(grouped_ranks):
        if my_rank in group:
            return i
    raise RuntimeError(f"Unable to find process rank {my_rank} in the set of grouped ranks {grouped_ranks}.")


def get_moe_group(moe_expert_count=2):
    """Return a process group for initializing a MoE layer."""
    if torch.distributed.is_initialized():
        world_size = torch.distributed.get_world_size()

        # At least as many experts as ranks: every rank gets its own single-rank group.
        if world_size <= moe_expert_count:
            assert moe_expert_count % world_size == 0
            moe_groups = [[i] for i in range(world_size)]

        # More ranks than experts: each expert is replicated across world_size // moe_expert_count ranks.
        else:
            assert world_size % moe_expert_count == 0
            ranks_per_group = world_size // moe_expert_count
            moe_groups = [[i + j * moe_expert_count for j in range(ranks_per_group)] for i in range(moe_expert_count)]

        moe_pgs = [torch.distributed.new_group(g) for g in moe_groups]
        # Find the index in the set of moe_groups which contains the current rank.
        my_group_idx = _find_my_group_index(moe_groups)
        return moe_pgs[my_group_idx]
    else:
        return torch.distributed.new_group([torch.distributed.get_rank()])
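
# Illustrative sketch, not part of the original test module: a worked example of the
# grouping arithmetic in get_moe_group() above. With world_size=4 and moe_expert_count=2,
# the ranks are partitioned as [[0, 2], [1, 3]], i.e. expert i is replicated on ranks
# {i, i + moe_expert_count}. The helper name is hypothetical and only demonstrates the
# arithmetic; it does not create any process groups.
def _example_moe_grouping(world_size=4, moe_expert_count=2):
    ranks_per_group = world_size // moe_expert_count
    # For the default arguments this evaluates to [[0, 2], [1, 3]].
    return [[i + j * moe_expert_count for j in range(ranks_per_group)] for i in range(moe_expert_count)]
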
def init_and_run(fn, args, rank, world_size, filename, filename_rpc):
    """Initialize and run the unit test for testing replicated MoE groups."""
    dist_init(rank, world_size, filename, filename_rpc)
    torch.cuda.set_device(rank)
    group = torch.distributed.new_group()
    # Specify the moe_group used to initialize the MoE layers with.
    fn(rank, group, *args, expert_group=get_moe_group())


class TestOptimizerUtils(DistributedTest):
    @parameterized.expand(
        [[functools.partial(SGD, momentum=0.9), True], [SGD, False], [Adam, False], [Adadelta, True], [Adam, True]],
        name_func=rename_test,
    )
    def test_consolidate_optimizer(self, optim_fn, transformer):
        config = {"mixed_precision": True, "flatten_parameters": True}
        config["compute_dtype"] = torch.float32
        test_fn = functools.partial(
            self._test_consolidated_optimizer, config, optim_fn=optim_fn, transformer=transformer
        )
        spawn_and_init(test_fn, world_sizes=[min(torch.cuda.device_count(), 4)])

    @parameterized.expand(
        [[SGD, False], [Adam, False]],
        name_func=rename_test,
    )
    def test_consolidate_optimizer_diff_world_size(self, optim_fn, transformer):
        if torch.cuda.device_count() < 4:
            raise unittest.SkipTest("This test requires at least 4 GPUs.")
        config = {"mixed_precision": True, "flatten_parameters": True}
        config["compute_dtype"] = torch.float32
        test_fn = functools.partial(self._test_consolidated_optimizer, config, optim_fn=Adam, transformer=transformer)
        spawn_and_init_multiple_groups(test_fn, world_sizes=[min(torch.cuda.device_count(), 4)])

    @classmethod
    def _test_consolidated_optimizer(
        self, config, rank, group, optim_fn=torch.optim.SGD, transformer=False, expert_group=None
    ):
        """FSDP.gather_full_optim_state_dict() should return something very similar to optimizer.state_dict()"""
        # Establish reference behavior.
        if transformer:
            unwrapped_model = TransformerWithSharedParams(group, wrapper_config=config).cuda()
            fsdp = self.get_wrapped_model(group, config=config).cuda()
        else:
            unwrapped_model = MixtureOfExperts(group, wrapper_config=None, expert_group=expert_group).cuda()
            fsdp = FullyShardedDataParallel(
                MixtureOfExperts(group, wrapper_config=config, expert_group=expert_group)
            ).cuda()

        try:
            fsdp_optim = optim_fn(
                fsdp.parameters(),
                lr=0.01,
            )
            optim_unwrapped = optim_fn(unwrapped_model.parameters(), lr=0.01)
        except TypeError:  # Adadelta
            fsdp_optim = optim_fn(fsdp.parameters())
            optim_unwrapped = optim_fn(unwrapped_model.parameters())

        fsdp_optim.zero_grad()
        optim_unwrapped.zero_grad()
        with torch.cuda.amp.autocast(enabled=True):
            x = fsdp.module.get_input(torch.device("cuda"))
            output = fsdp(*x)
            loss = fsdp.module.get_loss(x, output).to("cuda")
            fsdp.module.run_backward(loss)
            fsdp_optim.step()

            output = unwrapped_model(*x)
            loss = unwrapped_model.get_loss(x, output)
            unwrapped_model.run_backward(loss)
            optim_unwrapped.step()
        unwrapped_sd = optim_unwrapped.state_dict()

        if not transformer and not expert_group:
            no_broadcast_children = [x for x in fsdp._fsdp_instances if x.no_broadcast_optim_state]
            assert len(no_broadcast_children) == 1, f"Length of non shared params {len(no_broadcast_children)}"
            assert fsdp._fsdp_instances[-1].no_broadcast_optim_state

        torch.cuda.empty_cache()
        cuda_gb_before = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024 ** 3
        tstart = time()
        sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)
        duration = time() - tstart
        assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"

        cuda_gb_after = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024 ** 3
        mem_usg_gb = cuda_gb_after - cuda_gb_before
        assert mem_usg_gb == 0, f"gather_full_optim_state_dict used {mem_usg_gb:.2f} CUDA GB, max allowed is 0"
        assert cuda_gb_after > 0, "got 0 memory usage, logging is broken"

        if fsdp.rank > 0:
            assert sd is None
            return
        # assert whole state dict on CPU
        for k, v in sd["state"].items():
            for buffer_name, t in v.items():
                if torch.is_tensor(t):
                    msg = f"got device {t.device} for {k}: {buffer_name}. expected CPU"
                    assert t.device == torch.device("cpu"), msg

        if expert_group:
            sd_state = recursive_copy_to_device(sd["state"], non_blocking=False, device="cpu")
            orig_state = recursive_copy_to_device(unwrapped_sd["state"], non_blocking=False, device="cpu")
            assert_equal(len(sd_state.keys()), len(orig_state.keys()))
            assert_equal(
                sum([all_tensors_numel_except_for_step(v) for k, v in sd_state.items()]),
                sum([all_tensors_numel_except_for_step(v) for k, v in orig_state.items()]),
            )
            return

        unflat_state = sd["state"]
        assert "uncollected_local_ids" in sd
        shard_sd = fsdp.get_shard_from_optim_state_dict(sd)
        shard_sd = recursive_copy_to_device(shard_sd, non_blocking=False, device="cpu")
        state_after_get_shard = sd["state"]
        assert objects_are_equal(unflat_state, state_after_get_shard)  # no side effects.

        assert_equal(len(sd["state"]), len(unwrapped_sd["state"]))
        assert_equal(len(sd["param_groups"][0]["params"]), len(unwrapped_sd["param_groups"][0]["params"]))
        assert_equal(
            sum([all_tensors_numel_except_for_step(v) for k, v in sd["state"].items()]),
            sum([all_tensors_numel_except_for_step(v) for k, v in unwrapped_sd["state"].items()]),
        )

        original_shard_sd = fsdp_optim.state_dict()
        assert_equal(len(shard_sd["state"]), len(original_shard_sd["state"]))
        assert_equal(shard_sd.keys(), original_shard_sd.keys())
        original_shard_sd = recursive_copy_to_device(original_shard_sd, non_blocking=False, device="cpu")

        # Before asserting that the dicts are equal, we check keys individually to allow nice tracebacks.
        assert_equal(
            [all_tensors_numel_except_for_step(v) for k, v in shard_sd["state"].items()],
            [all_tensors_numel_except_for_step(v) for k, v in original_shard_sd["state"].items()],
        )
        assert_equal(
            [v for k, v in shard_sd["param_groups"][0].items()],
            [v for k, v in original_shard_sd["param_groups"][0].items()],
        )
        assert objects_are_equal(shard_sd["state"], original_shard_sd["state"])
        assert objects_are_equal({k: shard_sd[k] for k in original_shard_sd}, original_shard_sd)

    def test_named_params_ordering(self):
        """Test assumption of consolidate_optimizer_state_dict"""
        group = DummyProcessGroup(0, 1)
        model = TransformerWithSharedParams(group)
        named_pars = [p for n, p in model.named_parameters()]
        for i, p in enumerate(model.parameters()):
            assert objects_are_equal(p, named_pars[i])

    def test_is_singleton_tensor(self):
        assert is_singleton_tensor(torch.tensor(4.0))
        assert not is_singleton_tensor(torch.tensor([4.0]))
        assert not is_singleton_tensor(torch.tensor([4.0, 5.0]))
        assert not is_singleton_tensor([4.0])
        assert not is_singleton_tensor(4.0)
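
# Illustrative sketch, not part of the original test suite: the checkpoint round trip that
# the assertions above are meant to guarantee. The helper name and `path` argument are
# hypothetical; it assumes `model` is a FullyShardedDataParallel instance, `optim` is its
# local optimizer, and torch.distributed is already initialized.
def _example_full_optim_checkpoint_round_trip(model, optim, path):
    # Gather the full, unflattened optimizer state onto rank 0 (other ranks receive None).
    full_sd = model.gather_full_optim_state_dict(optim, recipient_rank=0)
    if full_sd is not None:
        torch.save(full_sd, path)
    torch.distributed.barrier()
    # Every rank reloads the consolidated state and extracts the shard matching its own
    # flattened parameters, which it can then load into its local optimizer.
    full_sd = torch.load(path, map_location="cpu")
    shard_sd = model.get_shard_from_optim_state_dict(full_sd)
    optim.load_state_dict(shard_sd)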