Commit f2af4c66 authored by Min Xu, committed by GitHub

[feat] [FSDP]: add experimental support to shared weights (#836)



* added a new test, passing without shared weights

* tested weight sharing

* added the test to test list file

* extended to world_size = 2

* fixed test

* [feat]: add limited and experimental support for shared parameter

* fixed tests

* simplify to work with layers that have at least 1 non-shared param and add code to pick up the linked_param field for sharding the shared param

* fixed the case where linked param is not in separate FSDP

* changelog and remove old code
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent a9fcaa28
...@@ -10,6 +10,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- LayerwiseMemoryTracker[feature][experimental] - This is a new experimental tool to help track, visualize and suggest fix for memory issues occurring during the forward/backward pass of your models. [#808]
- [FSDP]: limited support for shared weights between FSDP wrappers. This allows large parameter
and gradient memory to be sharded even though the shared weight is needed by
different layers. [#836]
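For reference, the entry above covers weight sharing of the following form, mirroring what the new unit test exercises (a minimal illustration; the layer names are not part of the library):

```python
import torch.nn as nn

# Two layers that reuse the same weight tensor. With this change, the shared
# parameter and its gradient can still be sharded by FSDP even though two
# different FSDP-wrapped layers need it.
l1 = nn.Linear(4, 4, bias=True)
l3 = nn.Linear(4, 4, bias=True)
l1.weight = l3.weight  # weight sharing across layers
```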
## [0.4.1] - 2021-09-17
### Fixed
......
...@@ -414,6 +414,33 @@ class FullyShardedDataParallel(nn.Module):
assert isinstance(self._fsdp_wrapped_module, FlattenParamsWrapper)
return self._fsdp_wrapped_module
def append_shared_param(self, p: Parameter) -> None:
""" Add a param that's already owned by another FSDP wrapper.
.. warning:: This is experimental!
This only works with all sharing FSDP modules are un-flattened.
p must to be already sharded by the owning module.
Check the corresponding unit test to see how is it used and tested.
In particular, the sharing FSDP wrappers are "siblings" not "parent"
and "child" of each other in the nested module structure.
Args:
p (Parameter):
The shared parameter.
"""
assert self._is_root is None
assert not self.flatten_parameters
assert isinstance(p, Parameter)
assert p._is_sharded
p._is_shared = True
assert (
len(list(filter(lambda p: not (hasattr(p, "_is_shared") and p._is_shared), self.params))) > 0
), "Must have at least 1 non-shared param."
self.params.append(p)
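A minimal usage sketch of ``append_shared_param``, modeled on the unit test added in this commit; the module and attribute names are illustrative, and a distributed process group is assumed to be initialized. Note the constraints enforced above: the sharing wrappers are un-flattened siblings, the shared param is already sharded by its owning wrapper, and the non-owning wrapper keeps at least one non-shared param (here, its bias).

```python
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

class TiedBlocks(nn.Module):
    def __init__(self):
        super().__init__()
        a = nn.Linear(4, 4, bias=True).cuda()
        b = nn.Linear(4, 4, bias=True).cuda()
        a.weight = b.weight  # tie the weights before wrapping
        # Wrap both sharing layers un-flattened; they become sibling FSDP wrappers.
        self.a = FSDP(a, flatten_parameters=False)
        self.b = FSDP(b, flatten_parameters=False)
        # Register the already-sharded shared weight with the non-owning wrapper.
        self.b.append_shared_param(self.a.module.weight)

    def forward(self, x):
        return self.b(self.a(x))
```

In the unit test, the whole model is additionally wrapped in an outer FSDP instance before training, and the result is checked against a non-FSDP reference run.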
def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
"""
Applies ``fn`` recursively to every submodule (as returned by
...@@ -916,8 +943,16 @@ class FullyShardedDataParallel(nn.Module):
yield
finally:
stack.close()
non_shared_params = self.params
# filter out shared params for all but the owner FSDP module.
if len(full_tensors) < len(non_shared_params):
non_shared_params = list(
filter(lambda p: not (hasattr(p, "_is_shared") and p._is_shared), self.params)
)
assert len(full_tensors) == len(
non_shared_params
), f"{len(full_tensors)} vs. {len(non_shared_params)}"
for p, (full_tensor, safe_to_free) in zip(non_shared_params, full_tensors):
if not volatile:
# Copy any changes made to the full params back into
# the corresponding local shards.
...@@ -1367,6 +1402,17 @@ class FullyShardedDataParallel(nn.Module):
if param.grad is None:
return
if hasattr(param, "_linked_param"):
# This links to a shared param. We should finalize the linked param here.
assert param.shape == (1,), param.shape
# If the _is_shared flag is set, then this shared weight is indeed being
# shared between different FSDP wrappers. Otherwise, they are linked but
# likely in the same FSDP wrapper, which means we shouldn't finalize the
# linked param.
if hasattr(param._linked_param, "_is_shared") and param._linked_param._is_shared:
param = param._linked_param
assert param.grad is not None
if param.grad.requires_grad:
raise RuntimeError("FSDP only works with gradients that don't require gradients")
...@@ -1650,6 +1696,11 @@ class FullyShardedDataParallel(nn.Module):
if not p._is_sharded: # e.g., when world_size == 1
update_p_data()
else:
# Skip if already built. Only a shared param can be rebuilt multiple times.
# A corner case is p._orig_size = (1,), which means the shape equality is
# not a perfect check. But we assume we don't share a param with shape (1,).
if p.data.shape == p._orig_size and hasattr(p, "_is_shared") and p._is_shared:
continue
# If self.move_params_to_cpu and force_full_precision, we need to cast
# the FP32 CPU param to CUDA for the all-gather.
p_data = p.data.to(p._full_param_padded.device, non_blocking=True)
...@@ -1704,7 +1755,7 @@ class FullyShardedDataParallel(nn.Module):
assert p._fp16_shard.storage().size() != 0
p.data = p._fp16_shard
else:
assert p._full_param_padded.storage().size() != 0, f"{p._orig_size} {id(self)}"
p.data = p._full_param_padded[: p._orig_size.numel()].view(p._orig_size)
@torch.no_grad()
......
...@@ -9,6 +9,7 @@ class Parameter(Tensor):
# These are dynamic attributes added by shard_params_data_parallel class.
# Added here for better type checking.
_is_sharded: bool
_is_shared: bool
_orig_size: Size
_cpu_grad: Tensor
_full_param_padded: Tensor
...@@ -16,6 +17,7 @@ class Parameter(Tensor):
_fp16_shard: Optional[Tensor]
_shard_bwd_hook: Tuple[Any, Any]
_saved_grad_shard: Tensor
_linked_param: Parameter
def __new__(cls, data: Tensor, requires_grad: builtins.bool = True): ...
......
tests/nn/data_parallel/test_fsdp_shared_weights.py
tests/nn/data_parallel/test_fsdp_pre_backward_hook.py
tests/nn/data_parallel/test_fsdp_overlap.py
tests/nn/data_parallel/test_fsdp_multiple_forward.py
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring
""" Test FSDP with shared weights between wrappers. """
from copy import deepcopy
import pytest
import torch
import torch.multiprocessing as mp
from torch.nn import Linear, Module
from torch.optim import SGD
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.utils.testing import dist_init, objects_are_equal, skip_if_single_gpu, teardown, temp_files_ctx
class Model(Module):
def __init__(self, with_fsdp=False, inner_flat=False, sharing=None):
super().__init__()
self.l0 = Linear(4, 4, bias=True).cuda()
self.l1 = Linear(4, 4, bias=True).cuda()
self.l2 = Linear(4, 4, bias=True).cuda()
self.l3 = Linear(4, 4, bias=True).cuda()
# Share the weights. The layer must have at least 1 param that's not
# shared. Therefore, we have bias=True and test sharing either the
# weight or the bias.
if sharing == "share_only_weights":
self.l1.weight = self.l3.weight
elif sharing == "share_only_bias":
self.l1.bias = self.l3.bias
else:
assert sharing is None or sharing == "share_none"
if with_fsdp:
# Shared layers must be un-flattened.
self.l1 = FSDP(self.l1, flatten_parameters=False)
self.l2 = FSDP(self.l2, flatten_parameters=inner_flat)
self.l3 = FSDP(self.l3, flatten_parameters=False)
if sharing in ["share_only_weights"]:
self.l3.append_shared_param(self.l1.module.weight)
if sharing in ["share_only_bias"]:
self.l3.append_shared_param(self.l1.module.bias)
def forward(self, x):
x = self.l0(x)
x = self.l1(x)
x = self.l2(x)
x = self.l3(x)
return x
# A fixture to get tempfiles and ensure they are cleaned up.
@pytest.fixture()
def temp_files():
# dist_init needs 2 files + 3 files for before state, after state, in_data.
with temp_files_ctx(5) as files:
yield files
@skip_if_single_gpu
@pytest.mark.parametrize("outer_flat", ["outer_flat", "outer_nonflat"])
@pytest.mark.parametrize("inner_flat", ["inner_flat", "inner_nonflat"])
@pytest.mark.parametrize("sharing", ["share_none", "share_only_weights", "share_only_bias"])
def test_shared_weight(temp_files, outer_flat, inner_flat, sharing):
"""Test FSDP with a model with shared weights."""
outer_flat = outer_flat == "outer_flat"
inner_flat = inner_flat == "inner_flat"
world_size = 2
# Get reference results.
model = Model(sharing=sharing)
sd_before = deepcopy(model.state_dict())
in_data = torch.rand(1, 4).cuda()
_train(model, in_data, world_size)
sd_after = deepcopy(model.state_dict())
# Before and after state should not be equal.
assert not objects_are_equal(sd_before, sd_after)
# Save data
torch.save(sd_before, temp_files[2])
torch.save(sd_after, temp_files[3])
torch.save(in_data, temp_files[4])
# Run FSDP
mp.spawn(
_dist_worker, (world_size, temp_files, outer_flat, inner_flat, sharing), nprocs=world_size,
)
def _dist_worker(rank, world_size, files, outer_flat, inner_flat, sharing):
# Get data from files.
file1, file2, sd_before, sd_after, in_data = files
sd_before = torch.load(sd_before, map_location=lambda storage, loc: storage.cuda(rank))
sd_after = torch.load(sd_after, map_location=lambda storage, loc: storage.cuda(rank))
in_data = torch.load(in_data, map_location=lambda storage, loc: storage.cuda(rank))
result = dist_init(rank=rank, world_size=world_size, filename=file1, filename_rpc=file2)
assert result, "Dist init failed"
fsdp_model = FSDP(Model(with_fsdp=True, inner_flat=inner_flat, sharing=sharing), flatten_parameters=outer_flat)
fsdp_model.load_state_dict(sd_before)
_train(fsdp_model, in_data)
objects_are_equal(sd_after, fsdp_model.state_dict(), raise_exception=True)
teardown()
def _train(model, in_data, steps_per_iter=1):
optim = SGD(model.parameters(), lr=0.1)
for _ in range(3):
# Simulate multiple ranks.
for _ in range(steps_per_iter):
out = model(in_data)
out.sum().backward()
# Simulate averaging gradients across ranks.
if steps_per_iter > 1:
with torch.no_grad():
for p in model.parameters():
p.grad /= steps_per_iter
optim.step()
model.zero_grad(set_to_none=True)