Unverified commit 40e7450f authored by anj-s, committed by GitHub

[fix][FSDP] Add support for saving optimizer state with expert replication (#936)

* checkpoint tests

* checkpoint tests

* fix tests

* lint fixes

* remove prints

* lint fixes

* add comments

* add changelog

* more cleanup

* lint fix
parent cb72ae54
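
For context, the sketch below (Python, not part of this commit; the toy module shapes, the `build_model`/`save_full_optim_state` helper names, and the checkpoint path are illustrative assumptions) shows the kind of setup the fix targets: each expert wrapped in its own FSDP instance with a smaller expert-replication process group, followed by `gather_full_optim_state_dict()` to consolidate optimizer state for saving.

# Minimal sketch (not code from this commit) of the scenario the fix targets:
# an MoE-style model where each expert is wrapped in its own FSDP instance with
# a smaller "expert replication" process group, while the rest of the model uses
# the global group. gather_full_optim_state_dict() must then gather each
# expert's optimizer state over the expert group, not the global one.
import torch
import torch.distributed as dist
import torch.nn as nn

from fairscale.nn import FullyShardedDataParallel as FSDP


def build_model(expert_group):
    # Toy expert + shared layer; shapes are arbitrary.
    expert = FSDP(
        nn.Linear(16, 16),
        process_group=expert_group,
        process_group_reduce_scatter=expert_group,
    )
    shared = nn.Linear(16, 16)
    # The root FSDP instance uses the default (global) process group.
    return FSDP(nn.Sequential(shared, expert))


def save_full_optim_state(model, optimizer, path="optim_state.pt"):
    # Consolidates the sharded optimizer state onto rank 0 before saving.
    full_osd = model.gather_full_optim_state_dict(optimizer)
    if dist.get_rank() == 0:
        torch.save(full_osd, path)
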
@@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - FSDP: Added skip_params_check_for_root flag to default_auto_wrap_policy which,
   if set, wraps the root module regardless of how many unwrapped params there were
   left after children were wrapped. [#930]
+- FSDP: Add support for saving optimizer state when using expert replicas with FSDP.

 ## [0.4.5] - 2022-01-14
@@ -2238,6 +2238,11 @@ class FullyShardedDataParallel(nn.Module):
             # param that has the optimizer state. So we handle it with the correct
             # parameter list.
             non_shared_params = cast(FullyShardedDataParallel, self._fsdp_instances[k]).non_shared_params()
+            # This is the world size and process group of the FSDP submodule, which can be
+            # different from the parent module's. For example, when FSDP is used with MoE.
+            non_shared_world_size = self._fsdp_instances[k].world_size
+            non_shared_process_group = self._fsdp_instances[k].process_group
             assert (
                 len(non_shared_params) == 1
             ), f"Only flatten param or a single non-shared param is supported: len={len(non_shared_params)}"
@@ -2250,15 +2255,15 @@ class FullyShardedDataParallel(nn.Module):
                 if ou.is_singleton_tensor(t):
                     if singleton_buffer is None:
-                        singleton_buffer = list(t.new_zeros(self.world_size).chunk(self.world_size))
-                    dist.all_gather(singleton_buffer, t, group=self.process_group)
+                        singleton_buffer = list(t.new_zeros(non_shared_world_size).chunk(non_shared_world_size))
+                    dist.all_gather(singleton_buffer, t, group=non_shared_process_group)
                     if self.rank == 0:
                         singleton_state[k][buffer_name] = [x.cpu().squeeze() for x in singleton_buffer]
                         assert ou.is_singleton_tensor(singleton_state[k][buffer_name][0])
                 elif torch.is_tensor(t):
                     if buffer is None:
-                        buffer = list(t.new_zeros(*desired_buffer_size).chunk(self.world_size))
-                    dist.all_gather(buffer, t, group=self.process_group)
+                        buffer = list(t.new_zeros(*desired_buffer_size).chunk(non_shared_world_size))
+                    dist.all_gather(buffer, t, group=non_shared_process_group)
                     if self.rank == 0:
                         gathered_state[k][buffer_name] = [x.cpu() for x in buffer]
                 elif self.rank == 0:  # Add non tensor state
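
The change above only swaps which world size and process group the gather uses. As a standalone illustration (a toy script with a made-up two-group layout, gloo backend on CPU; not FSDP code), the sketch below gathers a per-rank value over a subgroup and shows why the receive buffer must be chunked by the subgroup's world size rather than the global one:

# Hypothetical standalone sketch of an all_gather over a subgroup: the receive
# buffer needs one chunk per rank *in that subgroup*, not per rank in the global
# group -- the same reason the FSDP change above switches to
# non_shared_world_size / non_shared_process_group.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29511"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Two subgroups of replicated "experts": ranks {0, 2} and {1, 3}.
    groups = [dist.new_group([0, 2]), dist.new_group([1, 3])]
    my_group = groups[rank % 2]
    sub_world_size = 2  # size of the subgroup, not the global world size

    t = torch.tensor([float(rank)])
    # Buffer sized by the subgroup world size; sizing it by the global world
    # size (4 here) would not match the number of ranks participating in the
    # collective.
    buffer = list(t.new_zeros(sub_world_size).chunk(sub_world_size))
    dist.all_gather(buffer, t, group=my_group)
    print(f"rank {rank} gathered {[x.item() for x in buffer]}")

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(4,), nprocs=4)
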
@@ -775,7 +775,7 @@ class DummyDDP(nn.Module):
 class MixtureOfExperts(NestedWrappedModule):
-    def __init__(self, group, wrapper_config, checkpoint_act=False, delay_before_free_ms=0):
+    def __init__(self, group, wrapper_config, checkpoint_act=False, delay_before_free_ms=0, expert_group=None):
         super().__init__(group, wrapper_config)
         self.group = group
         self.delay_before_free_ms = delay_before_free_ms
@@ -801,9 +801,9 @@ class MixtureOfExperts(NestedWrappedModule):
             shared = checkpoint_wrapper(shared)
         if wrapper_config is not None:
-            # we create a process group of size 1 for the expert params
+            # we create a process group of size >= 1 for the expert params
             # we also need to pass that group as the reduce_scatter group.
-            expert_group = torch.distributed.new_group([group.rank()])
+            expert_group = expert_group or torch.distributed.new_group([group.rank()])
             expert = FullyShardedDataParallel(
                 expert, process_group=expert_group, process_group_reduce_scatter=expert_group, **wrapper_config
             )
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 import functools
 from time import time
+import unittest

 from parameterized import parameterized
 import torch
@@ -12,7 +13,7 @@ from torch.optim import SGD, Adadelta, Adam  # type: ignore
 from fairscale.nn import FullyShardedDataParallel
 from fairscale.nn.data_parallel.fsdp_optim_utils import is_singleton_tensor
 from fairscale.utils.params import recursive_copy_to_device
-from fairscale.utils.testing import objects_are_equal
+from fairscale.utils.testing import dist_init, objects_are_equal, spawn_for_all_world_sizes

 from .test_fsdp import (
     DistributedTest,
@@ -37,6 +38,57 @@ def assert_equal(a, b):
     assert a == b, f"{a} != {b}"


+def spawn_and_init_multiple_groups(fn, args=None, **spawn_kwargs):
+    if args is None:
+        args = ()
+    run_fn = functools.partial(init_and_run, fn, args)
+    spawn_for_all_world_sizes(run_fn, **spawn_kwargs)
+
+
+def _find_my_group_index(grouped_ranks):
+    """Return the index corresponding to the MoE group of the current process."""
+    my_rank = torch.distributed.get_rank()
+    for i, group in enumerate(grouped_ranks):
+        if my_rank in group:
+            return i
+    raise RuntimeError(f"Unable to find process rank {my_rank} in the set of grouped ranks {grouped_ranks}.")
+
+
+def get_moe_group(moe_expert_count=2):
+    """Return a process group for initializing a MoE layer."""
+    if torch.distributed.is_initialized():
+        world_size = torch.distributed.get_world_size()
+
+        # If you have more experts than the world size.
+        if world_size <= moe_expert_count:
+            assert moe_expert_count % world_size == 0
+            moe_groups = [[i] for i in range(world_size)]
+
+        # If you have a larger world size than experts.
+        else:
+            assert world_size % moe_expert_count == 0
+            ranks_per_group = world_size // moe_expert_count
+            moe_groups = [[i + j * moe_expert_count for j in range(ranks_per_group)] for i in range(moe_expert_count)]
+
+        moe_pgs = [torch.distributed.new_group(g) for g in moe_groups]
+        # Find the index in the set of moe_groups which contains the current rank.
+        my_group_idx = _find_my_group_index(moe_groups)
+        return moe_pgs[my_group_idx]
+    else:
+        return torch.distributed.new_group([torch.distributed.get_rank()])
+
+
+def init_and_run(fn, args, rank, world_size, filename, filename_rpc):
+    """Initialize and run the unit test for testing replicated MoE groups."""
+    dist_init(rank, world_size, filename, filename_rpc)
+    torch.cuda.set_device(rank)
+    group = torch.distributed.new_group()
+    # Specify the moe_group used to initialize the MoE layers with.
+    fn(rank, group, *args, expert_group=get_moe_group())
+
+
 class TestOptimizerUtils(DistributedTest):
     @parameterized.expand(
         [[functools.partial(SGD, momentum=0.9), True], [SGD, False], [Adam, False], [Adadelta, True], [Adam, True]],
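
To make the `get_moe_group` layout concrete, the snippet below (plain Python, illustration only; it skips the actual `torch.distributed.new_group` calls) reproduces the rank partitioning the helper computes for a couple of world sizes:

# Sketch of the rank layout get_moe_group() computes (illustration only; the
# real helper also creates the torch.distributed process groups and returns
# the one containing the current rank).
def moe_rank_groups(world_size, moe_expert_count=2):
    if world_size <= moe_expert_count:
        # One group per rank, e.g. world_size=2, experts=2 -> [[0], [1]].
        return [[i] for i in range(world_size)]
    ranks_per_group = world_size // moe_expert_count
    # e.g. world_size=4, experts=2 -> [[0, 2], [1, 3]]: ranks 0 and 2 hold
    # replicas of expert 0, ranks 1 and 3 hold replicas of expert 1.
    return [[i + j * moe_expert_count for j in range(ranks_per_group)] for i in range(moe_expert_count)]


print(moe_rank_groups(4))  # [[0, 2], [1, 3]]
print(moe_rank_groups(8))  # [[0, 2, 4, 6], [1, 3, 5, 7]]
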
@@ -51,17 +103,33 @@ class TestOptimizerUtils(DistributedTest):
         spawn_and_init(test_fn, world_sizes=[min(torch.cuda.device_count(), 4)])

+    @parameterized.expand(
+        [[SGD, False], [Adam, False]],
+        name_func=rename_test,
+    )
+    def test_consolidate_optimizer_diff_world_size(self, optim_fn, transformer):
+        if torch.cuda.device_count() < 4:
+            raise unittest.SkipTest("This test requires at least 4 GPUs.")
+
+        config = {"mixed_precision": True, "flatten_parameters": True}
+        config["compute_dtype"] = torch.float32
+        test_fn = functools.partial(self._test_consolidated_optimizer, config, optim_fn=Adam, transformer=transformer)
+        spawn_and_init_multiple_groups(test_fn, world_sizes=[min(torch.cuda.device_count(), 4)])
+
     @classmethod
-    def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim.SGD, transformer=False):
+    def _test_consolidated_optimizer(
+        self, config, rank, group, optim_fn=torch.optim.SGD, transformer=False, expert_group=None
+    ):
         """FSDP.gather_full_optim_state_dict() should return something very similar to optimizer.state_dict()"""
         # Establish reference behavior.

         if transformer:
             unwrapped_model = TransformerWithSharedParams(group, wrapper_config=config).cuda()
             fsdp = self.get_wrapped_model(group, config=config).cuda()
         else:
-            unwrapped_model = MixtureOfExperts(group, wrapper_config=None).cuda()
-            fsdp = FullyShardedDataParallel(MixtureOfExperts(group, wrapper_config=config)).cuda()
+            unwrapped_model = MixtureOfExperts(group, wrapper_config=None, expert_group=expert_group).cuda()
+            fsdp = FullyShardedDataParallel(
+                MixtureOfExperts(group, wrapper_config=config, expert_group=expert_group)
+            ).cuda()

         try:
             fsdp_optim = optim_fn(
@@ -88,9 +156,9 @@ class TestOptimizerUtils(DistributedTest):
             optim_unwrapped.step()
         unwrapped_sd = optim_unwrapped.state_dict()

-        if not transformer:
+        if not transformer and not expert_group:
             no_broadcast_children = [x for x in fsdp._fsdp_instances if x.no_broadcast_optim_state]
-            assert len(no_broadcast_children) == 1
+            assert len(no_broadcast_children) == 1, f"Length of non shared params {len(no_broadcast_children)}"
             assert fsdp._fsdp_instances[-1].no_broadcast_optim_state

         torch.cuda.empty_cache()
         cuda_gb_before = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024 ** 3
@@ -115,6 +183,18 @@ class TestOptimizerUtils(DistributedTest):
                     msg = f"got device {t.device} for {k}: {buffer_name}. expected CPU"
                     assert t.device == torch.device("cpu"), msg

+        if expert_group:
+            sd_state = recursive_copy_to_device(sd["state"], non_blocking=False, device="cpu")
+            orig_state = recursive_copy_to_device(unwrapped_sd["state"], non_blocking=False, device="cpu")
+
+            assert_equal(len(sd_state.keys()), len(orig_state.keys()))
+
+            assert_equal(
+                sum([all_tensors_numel_except_for_step(v) for k, v in sd_state.items()]),
+                sum([all_tensors_numel_except_for_step(v) for k, v in orig_state.items()]),
+            )
+            return
+
         unflat_state = sd["state"]
         assert "uncollected_local_ids" in sd
         shard_sd = fsdp.get_shard_from_optim_state_dict(sd)
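
The new `expert_group` branch compares the consolidated state against the reference optimizer by counting tensor elements while ignoring Adam's per-parameter `step` counters. The helper it calls, `all_tensors_numel_except_for_step`, is defined elsewhere in the test module; a hypothetical sketch of the behaviour the assertion relies on (not the actual implementation) is:

# Hypothetical sketch of what a helper like all_tensors_numel_except_for_step
# could compute: the total number of elements across all tensor-valued entries
# of one parameter's optimizer state, skipping the scalar "step" counter that
# Adam keeps per parameter. The real helper in the test module may differ.
import torch


def all_tensors_numel_except_for_step(param_state: dict) -> int:
    total = 0
    for name, value in param_state.items():
        if name == "step":
            continue
        if torch.is_tensor(value):
            total += value.numel()
    return total


# Example: Adam-style state for one parameter with 10 elements.
state = {
    "step": 3,
    "exp_avg": torch.zeros(10),
    "exp_avg_sq": torch.zeros(10),
}
print(all_tensors_numel_except_for_step(state))  # 20
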