Unverified commit 0cbf3bab authored by Myle Ott, committed by GitHub

[perf] Further improve performance for FSDP.no_sync (#502)

parent aa9129a3
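
For context: ``FSDP.no_sync()`` returns a context manager that disables gradient synchronization, so gradients accumulate locally across several backward passes and are only reduced on the first backward outside the context. A minimal sketch of the pattern this PR optimizes (illustrative only; ``MyModule`` and ``batches`` are placeholders, and an initialized distributed process group is assumed):

    import torch
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    model = FSDP(MyModule().cuda())  # MyModule: any nn.Module (placeholder)

    with model.no_sync():  # accumulate gradients locally; no reduce-scatter yet
        for batch in batches[:-1]:
            model(batch).sum().backward()

    model(batches[-1]).sum().backward()  # final step synchronizes gradients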
@@ -901,11 +901,11 @@ class FullyShardedDataParallel(nn.Module):
         if param.grad.requires_grad:
             raise RuntimeError("FullyShardedDataParallel only works with gradients that don't require grad")

-        if not self._is_root or self._require_backward_grad_sync:
+        if self._require_backward_grad_sync or self.reshard_after_forward:
             # Free full params. As a special case, we don't free the full params
-            # on the root instance when in a ``no_sync`` context (as indicated
-            # by ``self._require_backward_grad_sync``), since we will need the
-            # params again immediately for the next forward.
+            # when in a ``no_sync`` context (as inversely indicated by
+            # ``self._require_backward_grad_sync``), since the params will not
+            # get updated before the next forward.
             self._free_full_params([param])

         if self.mixed_precision:
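
In effect, the new condition frees the full parameters after backward unless we are inside ``no_sync`` (``self._require_backward_grad_sync`` is False) and this instance keeps full parameters after forward anyway (``reshard_after_forward`` is False); in that case the unchanged params can be reused directly by the next forward. A sketch of the decision, written as a hypothetical standalone helper (not part of the FSDP API):

    def should_free_full_params(require_backward_grad_sync: bool, reshard_after_forward: bool) -> bool:
        # Keep full params only when gradient sync is off (``no_sync``) AND this
        # instance would not reshard after forward anyway: the params are not
        # updated between accumulation steps, so the next forward reuses them.
        return require_backward_grad_sync or reshard_after_forward

Compared to the old condition, which kept full params only on the root instance, the new one keeps them on every instance configured with ``reshard_after_forward=False``.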
@@ -4,15 +4,17 @@
 # LICENSE file in the root directory of this source tree.

 import functools
+import itertools
 import unittest
 from unittest.mock import patch

+from parameterized import parameterized
 import torch

 from fairscale.nn.data_parallel import FullyShardedDataParallel
 from fairscale.utils.testing import DummyProcessGroup, objects_are_equal

-from .test_fsdp import DistributedTest, NestedWrappedModule, spawn_and_init
+from .test_fsdp import DistributedTest, NestedWrappedModule, rename_test, spawn_and_init


 class TestNoSync(DistributedTest):
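
The new imports drive test parameterization: ``parameterized.expand`` generates one test method per argument list, and ``name_func`` (here fairscale's ``rename_test`` helper) controls the generated test names. A generic sketch of the mechanism (``Example`` is illustrative, not from this diff):

    import unittest
    from parameterized import parameterized

    class Example(unittest.TestCase):
        # Generates Example.test_flag_0, Example.test_flag_1, ... one per tuple.
        @parameterized.expand([(True,), (False,)])
        def test_flag(self, flag):
            self.assertIsInstance(flag, bool)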
@@ -94,21 +96,60 @@ class TestNoSync(DistributedTest):
         assert objects_are_equal(ref_grads, accumulated_grads, raise_exception=True)


+keys = ["reshard_after_forward", "mixed_precision"]
+COMM_CONFIG_OPTIONS = [[dict(zip(keys, config))] for config in itertools.product([True, False], repeat=len(keys))]
+
+
 class TestNoSyncCommunication(DistributedTest):
-    def test_communication(self):
-        config = {"mixed_precision": True}
+    @parameterized.expand(COMM_CONFIG_OPTIONS, name_func=rename_test)
+    def test_communication(self, config):
         fn = functools.partial(self._test_communication, config=config)
         spawn_and_init(fn)

+    @parameterized.expand(COMM_CONFIG_OPTIONS, name_func=rename_test)
+    def test_communication_nested(self, config):
+        fn = functools.partial(self._test_communication, config=config, nested_model=True)
+        spawn_and_init(fn)
+
     @classmethod
-    def _test_communication(self, rank, group, config):
+    def _test_communication(self, rank, group, config, nested_model=False):
         if group.size() == 1:
             return

+        # Turn off bucketing to accurately count number of reduce_scatters.
+        config["bucket_cap_mb"] = 0
+
+        if nested_model:
+            model = NestedWrappedModule(group, config)
+            model = FullyShardedDataParallel(model, group, **config).cuda()
+        else:
+            model = self.get_wrapped_model(group, config=config)
-        batch = model.module.get_input(torch.device("cuda"))
+
+        num_fsdp = 0
+        for child in model.modules():  # includes self
+            if isinstance(child, FullyShardedDataParallel) and len(child.params) > 0:
+                num_fsdp += 1
+
+        if config.get("reshard_after_forward", True):
+            # inside no_sync:
+            #   num_fsdp all-gathers in the forward
+            #   num_fsdp-1 all-gathers in the backward (except root)
+            # outside no_sync:
+            #   num_fsdp-1 all-gathers in the forward (except root)
+            #   num_fsdp-1 all-gathers in the backward (except root)
+            expected_all_gather1 = 2 * num_fsdp - 1
+            expected_all_gather2 = expected_all_gather1 + (2 * num_fsdp - 2)
+        else:
+            # inside no_sync:
+            #   num_fsdp all-gathers in the forward
+            # outside no_sync:
+            #   none
+            expected_all_gather1 = num_fsdp
+            expected_all_gather2 = num_fsdp
+
+        expected_reduce_scatter = num_fsdp
+
+        batch = model.module.get_input(torch.device("cuda"))

         with patch("torch.distributed.all_gather") as mock_all_gather:
             with patch("torch.distributed.reduce_scatter") as mock_reduce_scatter:
                 with model.no_sync():
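
``COMM_CONFIG_OPTIONS`` above expands to the four combinations of ``reshard_after_forward`` and ``mixed_precision``, and the expected all-gather counts follow the arithmetic in the comments. A worked check of that arithmetic (values chosen for illustration, e.g. a nested model with ``num_fsdp = 3`` wrapped instances):

    import itertools

    keys = ["reshard_after_forward", "mixed_precision"]
    configs = [dict(zip(keys, values)) for values in itertools.product([True, False], repeat=len(keys))]
    assert len(configs) == 4  # {True, False} x {True, False}

    # With reshard_after_forward=True and num_fsdp = 3:
    num_fsdp = 3
    expected_all_gather1 = 2 * num_fsdp - 1                   # 5: forward (3) + backward minus root (2), inside no_sync
    expected_all_gather2 = expected_all_gather1 + (2 * num_fsdp - 2)  # 9: plus one normal iteration outside no_sync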
@@ -116,15 +157,21 @@ class TestNoSyncCommunication(DistributedTest):
                     loss = model.module.get_loss(batch, output)
                     loss.backward()
-                assert mock_all_gather.call_count == 1
-                assert mock_reduce_scatter.call_count == 0
+                assert (
+                    mock_all_gather.call_count == expected_all_gather1
+                ), f"{mock_all_gather.call_count} != {expected_all_gather1}"
+                assert mock_reduce_scatter.call_count == 0, f"{mock_reduce_scatter.call_count} != 0"

                 output = model(*batch)
                 loss = model.module.get_loss(batch, output)
                 loss.backward()
-                assert mock_all_gather.call_count == 1
-                assert mock_reduce_scatter.call_count == 1
+                assert (
+                    mock_all_gather.call_count == expected_all_gather2
+                ), f"{mock_all_gather.call_count} != {expected_all_gather2}"
+                assert (
+                    mock_reduce_scatter.call_count == expected_reduce_scatter
+                ), f"{mock_reduce_scatter.call_count} != {expected_reduce_scatter}"


 if __name__ == "__main__":