Unverified Commit 05ce7971 authored by Myle Ott, committed by GitHub

[fix] FSDP: fix MoE corner case (fixes #467) (#501)

parent 02405740
@@ -1127,7 +1127,7 @@ class FullyShardedDataParallel(nn.Module):
         if params is None:
             params = self.params
         self.has_full_params = False
-        current_stream = torch.cuda.current_stream()
+        self._streams["all_gather"].wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(self._streams["all_gather"]):
             for p in params:
                 if not p._is_sharded:  # e.g., world_size == 1
@@ -1140,7 +1140,6 @@ class FullyShardedDataParallel(nn.Module):
                 # unshard parameters, we should reuse the original Tensor
                 # Storage object and unshard it in-place. For now, just resize
                 # the Storage to 0 to save memory.
-                p._full_param_padded.record_stream(current_stream)
                 free_storage_(p._full_param_padded)

     @torch.no_grad()
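The change above makes `_free_full_params` order the `all_gather` stream after the current compute stream (via `wait_stream`) before freeing the padded full-parameter storage, rather than recording each tensor on the compute stream with `record_stream`. A minimal sketch of that stream-ordering pattern outside of FSDP; the helper `free_on_side_stream` and the storage-resize trick are illustrative, not fairscale API:

```python
import torch

def free_on_side_stream(tensors, side_stream):
    # Make the side stream wait until all work already queued on the current
    # (compute) stream is ordered ahead of it, mirroring the wait_stream call
    # added in the diff above.
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        for t in tensors:
            # Resize the underlying storage to 0 to release the memory,
            # analogous in spirit to fairscale's free_storage_ helper.
            t.storage().resize_(0)

if torch.cuda.is_available():
    side = torch.cuda.Stream()
    x = torch.randn(1024, 1024, device="cuda")
    y = x @ x  # work queued on the current stream still reads x
    free_on_side_stream([x], side)
    torch.cuda.synchronize()
```

Without the ordering step, the free on the side stream could race ahead of compute kernels that still read the tensor, which is the corner case the MoE test below reproduces.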
@@ -287,6 +287,20 @@ class TestComparisonToPyTorchDDP(DistributedTest):
         )
         spawn_and_init(test_fn)

+    @parameterized.expand([[{"checkpoint_act": False}], [{"checkpoint_act": True}]], name_func=rename_test)
+    def test_mixture_of_experts_with_delay_before_free(self, moe_config):
+        fsdp_config = {"mixed_precision": True}
+        test_fn = functools.partial(
+            self._test_identical_outputs,
+            functools.partial(MixtureOfExperts, delay_before_free_ms=250, **moe_config),
+            fsdp_config,
+            # MixtureOfExperts implements custom reduce logic, so the reference
+            # behavior should use that logic instead of PyTorch DDP.
+            ref_ddp_fn=self._dummy_ddp_fn,
+            norm_type=None,
+        )
+        spawn_and_init(test_fn)
+
     def test_mixture_of_experts_grad_clip_breaks(self):
         config = {"mixed_precision": True}
         test_fn = functools.partial(
@@ -760,9 +774,10 @@ class DummyDDP(nn.Module):
 class MixtureOfExperts(NestedWrappedModule):
-    def __init__(self, group, wrapper_config, checkpoint_act=False):
+    def __init__(self, group, wrapper_config, checkpoint_act=False, delay_before_free_ms=0):
         super().__init__(group, wrapper_config)
         self.group = group
+        self.delay_before_free_ms = delay_before_free_ms

         # "expert" params are different on each rank
         torch.manual_seed(42 + group.rank())
@@ -787,6 +802,22 @@ class MixtureOfExperts(NestedWrappedModule):
         self.module = nn.Sequential(nn.Linear(8, 4), shared, expert, nn.Linear(4, 8))

+    def forward(self, x):
+        if self.delay_before_free_ms > 0:
+            expert = self.module[2]
+            if isinstance(expert, FullyShardedDataParallel):
+                orig_free_full_params = self.module[2]._free_full_params
+
+                def _free_full_params_with_delay(*args):
+                    torch.cuda._sleep(int(self.delay_before_free_ms * get_cycles_per_ms()))
+                    return orig_free_full_params(*args)
+
+                assert hasattr(expert, "_free_full_params")
+                with mock.patch.object(expert, "_free_full_params", _free_full_params_with_delay):
+                    return self.module(x)
+
+        return self.module(x)
+
     def run_backward(self, loss):
         loss.backward()
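The new test injects an artificial GPU-side delay into `_free_full_params` so that the free on the `all_gather` stream races against ongoing compute, reproducing the MoE corner case from #467. A small sketch of that delay-injection pattern using `unittest.mock` and `torch.cuda._sleep`; the helpers `patch_with_gpu_delay` and `_cycles_per_ms` are illustrative stand-ins (the test itself uses the repo's `get_cycles_per_ms` utility):

```python
from unittest import mock

import torch

def _cycles_per_ms():
    # Rough calibration of GPU spin cycles per millisecond using CUDA events
    # (requires a CUDA device); stands in for get_cycles_per_ms above.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    torch.cuda._sleep(1_000_000)
    end.record()
    end.synchronize()
    return 1_000_000 / start.elapsed_time(end)

def patch_with_gpu_delay(obj, method_name, delay_ms):
    # Return a context manager that replaces obj.<method_name> with a wrapper
    # that spins the GPU for roughly delay_ms before calling the original.
    orig = getattr(obj, method_name)
    cycles = int(delay_ms * _cycles_per_ms())

    def delayed(*args, **kwargs):
        torch.cuda._sleep(cycles)
        return orig(*args, **kwargs)

    return mock.patch.object(obj, method_name, delayed)

# Usage (hypothetical FSDP-wrapped module `expert` and input `x`):
# with patch_with_gpu_delay(expert, "_free_full_params", delay_ms=250):
#     out = expert(x)
```

Delaying the free rather than the compute keeps the model's numerics unchanged, so the test can still compare outputs against the reference DDP run while widening the window in which the old `record_stream`-based code could free storage too early.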