Unverified commit 0cbf3bab, authored by Myle Ott, committed by GitHub

[perf] Further improve performance for FSDP.no_sync (#502)

parent aa9129a3
```diff
@@ -901,11 +901,11 @@ class FullyShardedDataParallel(nn.Module):
         if param.grad.requires_grad:
             raise RuntimeError("FullyShardedDataParallel only works with gradients that don't require grad")
-        if not self._is_root or self._require_backward_grad_sync:
+        if self._require_backward_grad_sync or self.reshard_after_forward:
             # Free full params. As a special case, we don't free the full params
-            # on the root instance when in a ``no_sync`` context (as indicated
-            # by ``self._require_backward_grad_sync``), since we will need the
-            # params again immediately for the next forward.
+            # when in a ``no_sync`` context (as inversely indicated by
+            # ``self._require_backward_grad_sync``), since the params will not
+            # get updated before the next forward.
             self._free_full_params([param])
         if self.mixed_precision:
```
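With the change above, full params are freed after backward whenever gradients will be synchronized or `reshard_after_forward` is set; inside a `no_sync` context with `reshard_after_forward=False`, the gathered params therefore stay materialized across accumulation steps, which is what makes the repeated forwards cheap. A minimal sketch of the gradient-accumulation pattern this optimizes (not from this commit; assumes `torch.distributed` is already initialized, a CUDA device is available, and uses a placeholder `nn.Linear` model):

```python
import torch
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Placeholder model and data; real training code would wrap its own module.
model = FSDP(nn.Linear(8, 8).cuda(), reshard_after_forward=False)
optim = torch.optim.SGD(model.parameters(), lr=0.01)
batches = [torch.randn(4, 8, device="cuda") for _ in range(4)]

# Inside no_sync the reduce-scatter of gradients is skipped; after this
# commit, full params are also kept materialized (since
# reshard_after_forward=False), so the next forward needs no re-all-gather.
with model.no_sync():
    for batch in batches[:-1]:
        model(batch).sum().backward()

# The final micro-batch outside no_sync triggers gradient synchronization.
model(batches[-1]).sum().backward()
optim.step()
optim.zero_grad()
```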
```diff
@@ -4,15 +4,17 @@
 # LICENSE file in the root directory of this source tree.
 
 import functools
+import itertools
 import unittest
 from unittest.mock import patch
 
+from parameterized import parameterized
 import torch
 
 from fairscale.nn.data_parallel import FullyShardedDataParallel
 from fairscale.utils.testing import DummyProcessGroup, objects_are_equal
 
-from .test_fsdp import DistributedTest, NestedWrappedModule, spawn_and_init
+from .test_fsdp import DistributedTest, NestedWrappedModule, rename_test, spawn_and_init
 
 
 class TestNoSync(DistributedTest):
```
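The newly imported `rename_test` is used below as the `name_func` for `parameterized.expand`. Its actual definition lives in `test_fsdp`; a hypothetical equivalent, following the `name_func(testcase_func, param_num, param)` contract of the `parameterized` library, would be:

```python
from parameterized import parameterized

# Hypothetical stand-in for the rename_test helper imported from test_fsdp:
# build a readable generated-test name from the base name and the parameters.
def rename_test(testcase_func, param_num, param):
    return "%s_%s" % (testcase_func.__name__, parameterized.to_safe_name(str(param.args)))
```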
```diff
@@ -94,21 +96,60 @@ class TestNoSyncCommunication(DistributedTest):
         assert objects_are_equal(ref_grads, accumulated_grads, raise_exception=True)
 
 
+keys = ["reshard_after_forward", "mixed_precision"]
+COMM_CONFIG_OPTIONS = [[dict(zip(keys, config))] for config in itertools.product([True, False], repeat=len(keys))]
+
+
 class TestNoSyncCommunication(DistributedTest):
-    def test_communication(self):
-        config = {"mixed_precision": True}
+    @parameterized.expand(COMM_CONFIG_OPTIONS, name_func=rename_test)
+    def test_communication(self, config):
         fn = functools.partial(self._test_communication, config=config)
         spawn_and_init(fn)
 
+    @parameterized.expand(COMM_CONFIG_OPTIONS, name_func=rename_test)
+    def test_communication_nested(self, config):
+        fn = functools.partial(self._test_communication, config=config, nested_model=True)
+        spawn_and_init(fn)
+
     @classmethod
-    def _test_communication(self, rank, group, config):
+    def _test_communication(self, rank, group, config, nested_model=False):
         if group.size() == 1:
             return
 
-        model = self.get_wrapped_model(group, config=config)
+        # Turn off bucketing to accurately count number of reduce_scatters.
+        config["bucket_cap_mb"] = 0
+
+        if nested_model:
+            model = NestedWrappedModule(group, config)
+            model = FullyShardedDataParallel(model, group, **config).cuda()
+        else:
+            model = self.get_wrapped_model(group, config=config)
+
+        num_fsdp = 0
+        for child in model.modules():  # includes self
+            if isinstance(child, FullyShardedDataParallel) and len(child.params) > 0:
+                num_fsdp += 1
+
+        if config.get("reshard_after_forward", True):
+            # inside no_sync:
+            #   num_fsdp all-gathers in the forward
+            #   num_fsdp-1 all-gathers in the backward (except root)
+            # outside no_sync:
+            #   num_fsdp-1 all-gathers in the forward (except root)
+            #   num_fsdp-1 all-gathers in the backward (except root)
+            expected_all_gather1 = 2 * num_fsdp - 1
+            expected_all_gather2 = expected_all_gather1 + (2 * num_fsdp - 2)
+        else:
+            # inside no_sync:
+            #   num_fsdp all-gathers in the forward
+            # outside no_sync:
+            #   none
+            expected_all_gather1 = num_fsdp
+            expected_all_gather2 = num_fsdp
+
+        expected_reduce_scatter = num_fsdp
+
         batch = model.module.get_input(torch.device("cuda"))
 
         with patch("torch.distributed.all_gather") as mock_all_gather:
             with patch("torch.distributed.reduce_scatter") as mock_reduce_scatter:
                 with model.no_sync():
```
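To make the expected counts concrete, here is the arithmetic worked through under the hypothetical assumption of three parameter-owning FSDP instances (illustrative numbers, not from the commit):

```python
# Illustrative only: suppose the wrapped model contains num_fsdp == 3
# FSDP instances that own parameters.
num_fsdp = 3

# With reshard_after_forward=True, per the comments in the diff:
#   no_sync step: 3 forward all-gathers + 2 backward all-gathers (root skipped)
all_gather_after_step1 = 2 * num_fsdp - 1                             # == 5
#   sync step adds num_fsdp-1 forward and num_fsdp-1 backward all-gathers
all_gather_after_step2 = all_gather_after_step1 + (2 * num_fsdp - 2)  # == 9

# With reshard_after_forward=False, only the first forward gathers params,
# so the count stays at num_fsdp == 3 for both steps.
assert (all_gather_after_step1, all_gather_after_step2) == (5, 9)
```

The second hunk of the test file, with the assertions updated against these expected counts, follows.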
```diff
@@ -116,15 +157,21 @@ class TestNoSyncCommunication(DistributedTest):
                     loss = model.module.get_loss(batch, output)
                     loss.backward()
 
-                assert mock_all_gather.call_count == 1
-                assert mock_reduce_scatter.call_count == 0
+                assert (
+                    mock_all_gather.call_count == expected_all_gather1
+                ), f"{mock_all_gather.call_count} != {expected_all_gather1}"
+                assert mock_reduce_scatter.call_count == 0, f"{mock_reduce_scatter.call_count} != 0"
 
                 output = model(*batch)
                 loss = model.module.get_loss(batch, output)
                 loss.backward()
 
-                assert mock_all_gather.call_count == 1
-                assert mock_reduce_scatter.call_count == 1
+                assert (
+                    mock_all_gather.call_count == expected_all_gather2
+                ), f"{mock_all_gather.call_count} != {expected_all_gather2}"
+                assert (
+                    mock_reduce_scatter.call_count == expected_reduce_scatter
+                ), f"{mock_reduce_scatter.call_count} != {expected_reduce_scatter}"
 
 
 if __name__ == "__main__":
```
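For reference, `COMM_CONFIG_OPTIONS` as defined in the diff expands to four single-argument test cases, one per combination of the two flags; a small snippet to inspect it (the printed values follow directly from `itertools.product` ordering):

```python
import itertools

keys = ["reshard_after_forward", "mixed_precision"]
COMM_CONFIG_OPTIONS = [[dict(zip(keys, config))] for config in itertools.product([True, False], repeat=len(keys))]

print(len(COMM_CONFIG_OPTIONS))  # 4
print(COMM_CONFIG_OPTIONS[0])    # [{'reshard_after_forward': True, 'mixed_precision': True}]
print(COMM_CONFIG_OPTIONS[-1])   # [{'reshard_after_forward': False, 'mixed_precision': False}]
```

Each dict is wrapped in a one-element list so that `parameterized.expand` passes it as the single positional `config` argument of each generated test.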