Unverified Commit a1612d79 authored by Min Xu, committed by GitHub

[fix]: let FSDP handle model with multiple forward pass and checkpoint (#621)



* [fix]: let FSDP handle model with multiple forward pass and checkpoint

* try CI again

* save

* save

* fixed case with bn

* minor

* add the new file

* minor

* added test of a single case, runtime is about 50s

* enable all 8 test cases

* cleanup

* cleanup

* skip flatten case with 1.6 and 1.7

* minor
Co-authored-by: Min Xu <min.xu@acm.org>
parent 5cddaea4
......@@ -115,9 +115,15 @@ class FullyShardedDataParallel(nn.Module):
an assert on the backward pass. The solution is to leave some parameters
to the outer FSDP.
.. warning::
If activation checkpointing is used with FSDP, it is strongly encouraged
to use ``checkpoint_wrapper`` function from FairScale instead of the
``checkpoint`` function from PyTorch.
Args:
module (nn.Module):
module to checkpoint
module to be wrapped with FullyShardedDataParallel.
process_group (Optional):
process group for sharding
reshard_after_forward (bool, Optional):
......@@ -207,7 +213,7 @@ class FullyShardedDataParallel(nn.Module):
self.no_broadcast_optim_state = no_broadcast_optim_state
self.state_dict_device = state_dict_device or self.compute_device
self.gradient_predivide_factor: int = self.get_gradient_predivide_factor(self.world_size)
self.gradient_predivide_factor: float = self._get_gradient_predivide_factor(self.world_size)
self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
self.numel_padded_per_param: List[int] = []
......@@ -275,11 +281,31 @@ class FullyShardedDataParallel(nn.Module):
f"FSDP.__init__(done): total_init_time: {(init_end - init_start): .4f} num_params: {(sum(p.numel() for p in self.params))}"
)
def get_gradient_predivide_factor(self, world_size: int) -> int:
factor = 1
# Flag to guard against the pre-backward hook being executed multiple times per iteration.
# This is reset at the end of the backward pass.
self._pre_backward_hook_has_run = False
def _get_gradient_predivide_factor(self, world_size: int) -> float:
factor: int = 1
while world_size % factor == 0 and world_size / factor > factor:
factor = factor * 2
return factor
factor *= 2
return float(factor)
def set_gradient_divide_factors(self, pre: float, post: float, recursive: bool) -> None:
"""Allowing user to override the pre and post divide factors.
Args:
pre (float): divide factor before the reduction.
post (float): divide factor after the reduction.
recursive (bool): recursively set it for all child FSDP instances or not.
"""
self.assert_state(TrainingState.IDLE)
if recursive:
for module in self.modules():
if isinstance(module, FullyShardedDataParallel) and module != self:
module.set_gradient_divide_factors(pre, post, False)
self.gradient_predivide_factor = pre
self.gradient_postdivide_factor = post
@property
def module(self) -> nn.Module:
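To make the pre/post factors above concrete: the sketch below is an illustration only (not the FSDP reduction code), with a hypothetical helper name and a simulated all-reduce; it shows that dividing by ``pre`` before the reduction and by ``post`` afterwards still yields the mean gradient whenever ``pre * post == world_size``, while keeping intermediate FP16 values in a safer range.

import torch

def reduce_with_divide_factors(per_rank_grads, pre: float, post: float) -> torch.Tensor:
    # Stand-in for dist.all_reduce: sum the pre-divided gradients from all ranks.
    summed = torch.stack([g / pre for g in per_rank_grads]).sum(dim=0)
    # Post-divide after the reduction; with pre * post == world_size this is the mean.
    return summed / post

world_size = 4
grads = [torch.full((3,), float(rank + 1)) for rank in range(world_size)]
pre, post = 2.0, world_size / 2.0  # matches the default factors computed above for world_size=4
print(reduce_with_divide_factors(grads, pre, post))  # tensor([2.5, 2.5, 2.5]) == mean of 1..4

This is also why ``set_gradient_divide_factors`` exposes both knobs: the test below overrides them to (1.0, 2.0) so FSDP's reduction matches a manual DDP-style all-reduce followed by a division by 2.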
......@@ -943,7 +969,13 @@ class FullyShardedDataParallel(nn.Module):
self._use_fp32_param_shard()
# Register pre-backward hooks to all-gather the params for the backward
# pass (if needed).
# pass (if the output's grad is needed). This won't register anything if
# we are in eval mode.
#
# Some models run the forward pass multiple times; we need to register the
# pre-backward hook on every output, since the last output's hook fires
# first in the backward pass and has to do the setup. We use
# ``self._pre_backward_hook_has_run`` to prevent repeated overhead from
# multiple hook callbacks.
outputs = self._register_pre_backward_hooks(outputs)
# Done with a forward pass.
......@@ -953,16 +985,18 @@ class FullyShardedDataParallel(nn.Module):
def _register_pre_backward_hooks(self, outputs: Any) -> Any:
"""Register pre-backward hook to run before the wrapped module's
backward. Hooks should be attached to all outputs from the forward."""
backward. Hooks should be attached to all outputs from the forward.
Returns:
outputs: new outputs with hooks registered if they require gradients.
"""
if not torch.is_grad_enabled():
return outputs # don't register hooks if grad isn't enabled
pre_backward_hook_has_run = [False]
def _pre_backward_hook(*unused: Any) -> None:
if pre_backward_hook_has_run[0]:
return # only run once
pre_backward_hook_has_run[0] = True
if self._pre_backward_hook_has_run:
return # only run once (from multiple outputs or multiple forward passes)
self._pre_backward_hook_has_run = True
# Start of a backward pass.
self.assert_state([TrainingState.IDLE, TrainingState.BACKWARD_PRE])
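The run-once guard above can be hard to picture from the diff alone. The standalone sketch below is not the FSDP implementation (the helper name is hypothetical); it just shows the idea of attaching a gradient hook to every tensor output so that whichever output's backward fires first performs the setup exactly once, and later callbacks are no-ops.

import torch

def register_run_once_pre_backward(outputs, setup_fn):
    state = {"has_run": False}  # plays the role of self._pre_backward_hook_has_run

    def hook(grad):
        if not state["has_run"]:
            state["has_run"] = True
            setup_fn()  # e.g., all-gather the full params needed for the backward pass
        return grad

    def attach(t):
        if torch.is_tensor(t) and t.requires_grad:
            t.register_hook(hook)
        return t

    if torch.is_tensor(outputs):
        return attach(outputs)
    return type(outputs)(attach(t) for t in outputs)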
......@@ -1062,13 +1096,27 @@ class FullyShardedDataParallel(nn.Module):
the local optimizer only sees the relevant parameter shard.
"""
# First hook callback will see PRE state. If we have multiple params,
# then subsequent hook callbacks will see POST state.
# then subsequent hook callbacks will see POST state. When the checkpoint
# fwd counter is used, IDLE is also possible since the pre-backward hook
# is not triggered (see ``auto_wrap_bn`` below; we have to use
# FSDP(checkpoint(conv, FSDP(bn), ...)) with reshard_after_forward=False).
if hasattr(self, "_checkpoint_fwd_counter"):
self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST, TrainingState.IDLE])
else:
self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
self.training_state = TrainingState.BACKWARD_POST
if param.grad is None:
return
if param.grad.requires_grad:
raise RuntimeError("FullyShardedDataParallel only works with gradients that don't require grad")
raise RuntimeError("FullyShardedDataParallel only works with gradients that don't require gradients")
# If this is a checkpointed module, we check whether its forward
# counter has reached 0. If not, this is not yet the final backward
# call for this module, so we return early.
if hasattr(self._fsdp_wrapped_module, "_checkpoint_fwd_counter"):
if self._fsdp_wrapped_module._checkpoint_fwd_counter != 0:
return
if self._require_backward_grad_sync or self.reshard_after_forward:
# Free full params. As a special case, we don't free the full params
......@@ -1200,6 +1248,7 @@ class FullyShardedDataParallel(nn.Module):
for m in self.modules(): # includes self
if isinstance(m, FullyShardedDataParallel):
_remove_shard_bwd_hook(m)
m._pre_backward_hook_has_run = False
if m._has_params:
if any(p.requires_grad for p in m.params):
m.assert_state(TrainingState.BACKWARD_POST)
......@@ -1395,8 +1444,8 @@ class FullyShardedDataParallel(nn.Module):
# In case we are failing in the context of an autograd hook, asserting
# may not generate a useful message, so let's print it to be sure.
if self.rank == 0:
print(self)
print(msg)
print(f"Asserting FSDP instance is: {self}")
print(f"ERROR: {msg}")
traceback.print_stack()
raise ValueError(msg)
......@@ -1543,7 +1592,7 @@ class FullyShardedDataParallel(nn.Module):
v_shard = v[0] if self.rank >= len(v) else v[self.rank]
assert ou.is_singleton_tensor(v_shard)
else:
v_shard = v # dont shard entries that are not tensors
v_shard = v # don't shard entries that are not tensors
full_optim_state_dict["state"][id][k] = v_shard
return full_optim_state_dict
......@@ -1686,6 +1735,10 @@ def auto_wrap_bn(module: nn.Module, single_rank_pg: bool = False, process_group:
"process_group": pg,
"mixed_precision": False, # Keep the weights in FP32.
"flatten_parameters": False, # Do not flatten.
# Reshard==False is good for performance. When FSDP(checkpoint(FSDP(bn))) is used, this
# **must** be False because BN's FSDP wrapper's pre-backward callback isn't called
# within the checkpoint's outer backward when multiple forward passes are used.
"reshard_after_forward": False,
}
with enable_wrap(wrap_bn_only_policy, **fsdp_config):
......
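To make the wrapping order from the comment above concrete, here is a hedged sketch of the FSDP(checkpoint(conv, FSDP(bn), ...)) pattern built from the pieces touched by this PR. It assumes torch.distributed is already initialized and mirrors the new test file further below; the exact constructor arguments are illustrative, not prescriptive.

import torch.nn as nn
from fairscale.nn import checkpoint_wrapper
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import auto_wrap_bn

# Requires an initialized process group, e.g. torch.distributed.init_process_group(...).
block = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
block = auto_wrap_bn(block)  # inner FSDP(bn) with reshard_after_forward=False (see above)
block = checkpoint_wrapper(block, maintain_forward_counter=True)  # checkpoint the whole block
model = FSDP(block)  # outer FSDP around the checkpointed block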
......@@ -15,10 +15,12 @@ import torch.utils.checkpoint as torch_checkpoint
from fairscale.utils.containers import pack_kwargs, split_non_tensors, unpack_kwargs, unpack_non_tensors
from .misc import patch_batchnorm
from .misc import dec_counter, inc_counter, init_counter, patch_batchnorm
def checkpoint_wrapper(module: nn.Module, offload_to_cpu: bool = False) -> nn.Module:
def checkpoint_wrapper(
module: nn.Module, offload_to_cpu: bool = False, maintain_forward_counter: bool = False
) -> nn.Module:
"""
A friendlier wrapper for performing activation checkpointing.
......@@ -58,16 +60,23 @@ def checkpoint_wrapper(module: nn.Module, offload_to_cpu: bool = False) -> nn.Mo
Args:
module (nn.Module):
The module to be wrapped
offload_to_cpu (Optional, bool):
offload_to_cpu (bool):
Whether to offload activations to CPU.
maintain_forward_counter (bool):
If True, maintain a forward counter per inner module. The counter first
increases during the forward calls of the outer forward pass and then
decreases during the forward calls (recomputations) of the outer backward
pass. It is used by FullyShardedDataParallel.
Returns:
(nn.Module):
Wrapped module
"""
# Patch the batchnorm layers in case there are any.
# Patch the batchnorm layers in case there are any in this module.
patch_batchnorm(module)
if maintain_forward_counter:
init_counter(module)
# The use of weakref here is to prevent creating a ref cycle: m -> m.forward -> m.
# When such a cycle exists, gc won't collect the module when the module is freed.
# That causes GPU memory to be leaked. See the unit test for how we catch that.
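A small single-process usage sketch of the new flag follows. The counter attribute name comes from ``init_counter`` below, and the values noted in the comments are the expected ones given the inc/dec calls added to ``CheckpointFunction`` in this diff; they are not asserted here.

import torch
import torch.nn as nn
from fairscale.nn import checkpoint_wrapper

layer = checkpoint_wrapper(nn.Linear(4, 4), maintain_forward_counter=True)
x = torch.randn(2, 4, requires_grad=True)  # input requires grad so the checkpointed output is differentiable
out = layer(x) + layer(x)                  # two forward passes through the same checkpointed module
print(layer._checkpoint_fwd_counter)       # expected 2: incremented once per checkpointed forward
out.sum().backward()                       # recomputation during backward decrements the counter
print(layer._checkpoint_fwd_counter)       # expected 0 once the backward pass has finished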
......@@ -168,6 +177,8 @@ class CheckpointFunction(torch.autograd.Function):
with torch.no_grad():
unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)
outputs = run_function(*unpacked_args, **unpacked_kwargs)
the_module = unpacked_args[0]
inc_counter(the_module)
if not isinstance(outputs, torch.Tensor):
# Autograd Functions don't like non-Tensor outputs. We can split the
......@@ -200,6 +211,8 @@ class CheckpointFunction(torch.autograd.Function):
unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs)
outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs)
tensor_outputs, _ = split_non_tensors(outputs)
the_module = unpacked_args[0]
dec_counter(the_module)
# Set the states back to what it was at the start of this function.
set_rng_state(bwd_rng_state)
......
......@@ -27,7 +27,6 @@ def patch_batchnorm(module: nn.Module) -> List:
(list):
A list of hook handles that can be freed later.
"""
hooks = []
def pre_forward(module: _BatchNorm, input: Tensor) -> None:
if torch.is_grad_enabled():
......@@ -40,6 +39,7 @@ def patch_batchnorm(module: nn.Module) -> List:
return
module.track_running_stats = module._track_running_stats_backup
hooks = []
for name, child in module.named_modules():
# _BatchNorm is base for bn1d, bn2d, bn3d and sync_bn, apex_sync_bn, etc.
if isinstance(child, _BatchNorm):
......@@ -48,3 +48,28 @@ def patch_batchnorm(module: nn.Module) -> List:
post_handle = child.register_forward_hook(post_forward)
hooks += [pre_handle, post_handle]
return hooks
def init_counter(module: nn.Module) -> None:
"""Add a checkpoint forward pass counter to a module and all its child FSDP modules.
``inc_counter`` and ``dec_counter`` are used together with this to maintain counters
for FSDP to use in case of multiple forward pass and checkpoint being used at the same time.
"""
for mod in module.modules():
mod._checkpoint_fwd_counter = 0
def _add_counter(module: nn.Module, value: int) -> None:
if not hasattr(module, "_checkpoint_fwd_counter"):
return
for mod in module.modules():
mod._checkpoint_fwd_counter += value
def inc_counter(module: nn.Module) -> None:
_add_counter(module, 1)
def dec_counter(module: nn.Module) -> None:
_add_counter(module, -1)
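A quick illustration of the helpers above on a plain module tree (no FSDP or checkpointing involved): ``init_counter`` attaches the attribute to the module and every child, and inc/dec update all of them together, so a wrapping FSDP instance can read the counter on whichever submodule it holds. The import path is an assumption, matching the relative import shown earlier in this diff.

import torch.nn as nn
from fairscale.nn.misc import dec_counter, inc_counter, init_counter  # import path assumed

seq = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
init_counter(seq)   # attaches _checkpoint_fwd_counter = 0 to seq and every child
inc_counter(seq)    # outer forward pass #1
inc_counter(seq)    # outer forward pass #2
dec_counter(seq)    # one recomputation during the outer backward pass
print(seq._checkpoint_fwd_counter, seq[0]._checkpoint_fwd_counter)  # 1 1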
......@@ -26,6 +26,7 @@ you see fit, but refrain from ad-hoc test utils within the different feature set
relative imports.
"""
import contextlib
import functools
import inspect
import logging
......@@ -35,7 +36,7 @@ import random
import subprocess
import sys
import tempfile
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Tuple, Union
import numpy
import pytest
......@@ -645,3 +646,15 @@ def rmf(filename: str) -> None:
os.remove(filename)
except FileNotFoundError:
pass
@contextlib.contextmanager
def temp_files_ctx(num: int) -> Generator:
""" A context to get tempfiles and ensure they are cleaned up. """
files = [tempfile.mkstemp()[1] for _ in range(num)]
yield tuple(files)
# temp files could have been removed, so we use rmf.
for name in files:
rmf(name)
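Typical use in a test body (a hypothetical test; the helper hands back ``num`` temp file names and removes them when the block exits):

from fairscale.utils.testing import temp_files_ctx

def test_uses_temp_files():
    with temp_files_ctx(num=3) as (init_file, rpc_file, loss_file):
        # pass the file names to dist_init / mp.spawn; they are removed on exit
        ...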
......@@ -108,6 +108,8 @@ class Module(Generic[T_co]):
def extra_repr(self) -> str: ...
#MODIFIED BY TORCHGPIPE
# This is added by checkpoint_wrapper
_checkpoint_fwd_counter: int
# This is added by torchgpipe
training: bool
#END
tests/nn/misc/test_flatten_params_wrapper.py
tests/nn/data_parallel/test_fsdp.py
tests/nn/data_parallel/test_fsdp_freezing_weights.py
tests/nn/data_parallel/test_fsdp_multiple_forward_checkpoint.py
tests/nn/data_parallel/test_fsdp_multiple_wrapping.py
tests/nn/data_parallel/test_fsdp_freezing_weights.py
tests/nn/data_parallel/test_fsdp.py
tests/nn/misc/test_flatten_params_wrapper.py
......@@ -118,7 +118,7 @@ def temp_files():
@skip_if_single_gpu
def tests1(temp_files):
def test_freezing_weights(temp_files):
world_size = 2
# DDP
fsdp = False
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring
""" Test FSDP with multiple forward pass + checkpoint. """
import contextlib
import pickle
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
import torch.optim as optim
from fairscale.nn import checkpoint_wrapper
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import auto_wrap_bn
from fairscale.nn.wrap import enable_wrap, wrap
from fairscale.utils.testing import (
dist_init,
objects_are_equal,
skip_if_single_gpu,
teardown,
temp_files_ctx,
torch_version,
)
class Model(nn.Module):
def __init__(self):
super().__init__()
self.block1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3), nn.BatchNorm2d(64), nn.ReLU(inplace=True),)
self.block2 = nn.Sequential(
nn.Conv2d(64, 128, kernel_size=3),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.AdaptiveAvgPool2d(output_size=(1, 1)),
nn.Flatten(),
)
self.head = nn.Linear(128, 10)
def forward(self, x):
if isinstance(x, torch.Tensor):
return self.head(self.block2(self.block1(x)))
elif isinstance(x, list):
ys = [self.head(self.block2(self.block1(e))) for e in x]
return torch.cat(ys, dim=0)
def create_model(with_fsdp, with_checkpoint, mixed_precision, flatten, wrap_bn, fp32_reduce_scatter):
model = Model()
if with_fsdp:
if wrap_bn:
model.block1 = auto_wrap_bn(model.block1, single_rank_pg=False)
model.block2 = auto_wrap_bn(model.block2, single_rank_pg=False)
if with_checkpoint:
model.block2 = checkpoint_wrapper(model.block2, maintain_forward_counter=True)
with enable_wrap(
wrapper_cls=FSDP,
flatten_parameters=flatten,
mixed_precision=mixed_precision,
compute_dtype=torch.float32,
fp32_reduce_scatter=fp32_reduce_scatter,
):
model.block1 = wrap(model.block1)
model.block2 = wrap(model.block2)
model.head = wrap(model.head)
else:
if with_checkpoint:
model.block2 = checkpoint_wrapper(model.block2, maintain_forward_counter=False)
return model
def _distributed_worker(
gpu_id, world_size, with_fsdp, with_checkpoint, files, mixed_precision, flatten, wrap_bn, fp32_reduce_scatter
):
filename, filename_rpc = files[:2]
filename_loss = files[2:]
torch.cuda.set_device(gpu_id)
rank = gpu_id
result = dist_init(rank, world_size, filename, filename_rpc)
assert result, "Dist init failed"
# Use False below when debugging, since the error messages are not as good with cudnn enabled.
torch.backends.cudnn.enabled = True
# these make things deterministic.
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Ensure we have multiple forward passes.
batch = [
torch.randn(size=(2, 3, 224, 224)).cuda(),
torch.randn(size=(2, 3, 96, 96)).cuda(),
torch.randn(size=(2, 3, 96, 96)).cuda(),
]
if mixed_precision and not with_fsdp:
batch = [x.half() for x in batch]
model = create_model(with_fsdp, with_checkpoint, mixed_precision, flatten, wrap_bn, fp32_reduce_scatter)
model = model.cuda()
if with_fsdp:
model = FSDP(
model,
flatten_parameters=flatten,
mixed_precision=mixed_precision,
compute_dtype=torch.float32,
fp32_reduce_scatter=fp32_reduce_scatter,
)
model.set_gradient_divide_factors(1.0, 2.0, True)
no_sync_context = contextlib.suppress()
else:
# With DDP, we need no_sync and manual gradient reduction below because
# it can't handle multiple forward passes + checkpointing otherwise.
model = DistributedDataParallel(model, device_ids=[gpu_id])
no_sync_context = model.no_sync()
mp_context = contextlib.suppress()
if mixed_precision:
mp_context = torch.cuda.amp.autocast(enabled=True)
if gpu_id == 0:
print(model)
target = torch.LongTensor([0, 1, 2, 3, 4, 5]).cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
losses = {}
i = 0
with no_sync_context:
for iteration in range(3):
with mp_context:
out = model(batch)
loss = criterion(out, target)
print("Loss", iteration, ":", loss.item())
losses[f"iter_{i}"] = loss
i += 1
optimizer.zero_grad()
loss.backward()
# Manual grad reduction, no autocast.
if not with_fsdp:
for p in model.parameters():
dist.all_reduce(p.grad.data)
p.grad.data.div_(2.0)
# Stepping, no autocast
optimizer.step()
# Due to the dist.all_reduce code block above running under ddp.no_sync, we seem to hit a bug
# in DDP where tensor.cpu() and torch.save() calls both hang. FSDP is not affected.
# Therefore, we have to compare losses here instead of states.
with open(filename_loss[rank], "wb") as f:
pickle.dump(losses, f)
teardown()
@skip_if_single_gpu
@pytest.mark.parametrize("precision", ["full", "mixed"])
@pytest.mark.parametrize("flatten", ["flatten", "no_flatten"])
@pytest.mark.parametrize("wrap_bn", ["auto_wrap_bn", "no_auto_wrap_bn"])
def test_multiple_forward_checkpoint(precision, flatten, wrap_bn):
mixed_precision = precision == "mixed"
flatten = flatten == "flatten"
wrap_bn = wrap_bn == "auto_wrap_bn"
fp32_reduce_scatter = True if mixed_precision else None
if torch_version() < (1, 8, 0) and flatten:
# 1.6 and 1.7 throw this error:
# RuntimeError: Trying to backward through the graph a second time, but the saved
# intermediate results have already been freed. Specify retain_graph=True when calling
# backward the first time.
pytest.skip("older pytorch throws error when flatten is used")
world_size = 2
expected_losses = None
# Ensure ddp == ddp+ckpt == fsdp == fsdp+ckpt.
for with_fsdp in [False, True]:
for with_checkpoint in [False, True]:
# Get 4 files: 2 for dist_init and 1 per rank to save the losses.
with temp_files_ctx(num=2 + world_size) as temp_files:
mp.spawn(
_distributed_worker,
(
world_size,
with_fsdp,
with_checkpoint,
temp_files,
mixed_precision,
flatten,
wrap_bn,
fp32_reduce_scatter,
),
nprocs=world_size,
)
final_losses = {}
for rank in range(world_size):
with open(temp_files[2 + rank], "rb") as f:
final_losses[f"rank_{rank}"] = pickle.load(f)
if expected_losses is None:
expected_losses = final_losses
else:
print(f"fsdp: {with_fsdp} ckpt: {with_checkpoint}")
assert objects_are_equal(expected_losses, final_losses, raise_exception=True)