Unverified Commit 4fa2ab9b authored by Darryl Barnhart, committed by GitHub

[fix] FSDP intra-backwards gradient accumulation. (#784)

* [fix] FSDP intra-backwards gradient accumulation.

Ensure gradient reduction accumulates into the unsharded gradient tensor
within a backwards pass. This matters when an FSDP module is called
multiple times within a forward pass, and reduction is _not_ deferred
using activation checkpoint forward counters, bucketing or some other
mechanism.

Closes #780

* [refactor] Remove forward counters. Comments.

Removed forward counters from the activation checkpointing utility, now
that FSDP does not require them for correct operation. Added a more
detailed comment about memory usage behaviour during gradient reduction.

* [refactor] Delete deprecated forward counter usage.

* [refactor] Add state assertion at the end of the pre-backward hook.
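For context, a minimal PyTorch-only sketch (independent of this diff; names are illustrative) of the situation the fix targets: a module used twice in one forward pass receives two gradient contributions during a single backward pass, so any reduction hook that overwrites instead of accumulating would silently drop one of them.

```python
# Minimal PyTorch-only illustration (no FSDP involved).
import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
x = torch.randn(2, 4)

# Same layer called twice in one forward pass.
layer(layer(x)).sum().backward()
grad_two_calls = layer.weight.grad.clone()

# Single call for comparison: the second contribution is missing.
layer.zero_grad()
layer(x).sum().backward()
grad_one_call = layer.weight.grad.clone()

assert not torch.allclose(grad_two_calls, grad_one_call)
```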
parent 482944d9
@@ -16,7 +16,7 @@ import torch.utils.checkpoint as torch_checkpoint
 from fairscale.utils.containers import pack_kwargs, split_non_tensors, unpack_kwargs, unpack_non_tensors
 
-from .checkpoint_utils import dec_counter, inc_counter, init_counter, patch_batchnorm
+from .checkpoint_utils import patch_batchnorm
 
 # https://docs.python.org/3/library/threading.html#thread-local-data
@@ -95,9 +95,7 @@ def is_recomputing() -> bool:
     return thread_local.is_recomputing
 
 
-def checkpoint_wrapper(
-    module: nn.Module, offload_to_cpu: bool = False, maintain_forward_counter: bool = False
-) -> nn.Module:
+def checkpoint_wrapper(module: nn.Module, offload_to_cpu: bool = False,) -> nn.Module:
     """
     A friendlier wrapper for performing activation checkpointing.
@@ -139,10 +137,6 @@ def checkpoint_wrapper(
             The module to be wrapped
         offload_to_cpu (bool):
             Whether to offload activations to CPU.
-        maintain_forward_counter (bool):
-            If True, maintain a forward counter per inner module. The counter will first
-            increases in forward calls of outer forward pass and then decreases in the
-            forward calls of outer backward pass. It is used by FullyShardedDataParallel.
 
     Returns:
         (nn.Module):
@@ -151,9 +145,6 @@ def checkpoint_wrapper(
     # Patch the batchnorm layers in case there are any in this module.
     patch_batchnorm(module)
 
-    if maintain_forward_counter:
-        init_counter(module)
-
     # The use of weakref here is to prevent creating a ref cycle: m -> m.forward -> m.
     # When such cycle exists, gc won't collect the module when the module is freed.
     # That causes GPU memory to be leaked. See the unit test for how we catch that.
@@ -172,10 +163,6 @@ def _checkpointed_forward(
     module = weak_self()
 
     # If gradients are disabled, just use original `.forward()` method directly.
-    # Doing so also ensures the internal fwd counter is not incremented in the forward pass,
-    # which would be an issue during eval since there wouldn't be a corresponding backward pass
-    # to decrement the fwd counter.
-    # See https://github.com/facebookresearch/fairscale/pull/709.
    if not torch.is_grad_enabled() or thread_local.is_checkpointing_disabled:
        return original_forward(module, *args, **kwargs)
@@ -290,8 +277,6 @@ class CheckpointFunction(torch.autograd.Function):
         with torch.no_grad(), enable_checkpointing():
             unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)
             outputs = run_function(*unpacked_args, **unpacked_kwargs)
-            the_module = unpacked_args[0]
-            inc_counter(the_module)
 
         # Because we run with torch.no_grad(), we can't actually access
         # outputs.requires_grad. Instead, we manually compute it by
@@ -341,8 +326,6 @@ class CheckpointFunction(torch.autograd.Function):
             unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs)
             outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs)
             tensor_outputs, _ = split_non_tensors(outputs)
-            the_module = unpacked_args[0]
-            dec_counter(the_module)
 
         # Set the states back to what it was at the start of this function.
         set_rng_state(bwd_rng_state)
......
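With the forward-counter argument gone, wrapping is a one-liner. A hedged usage sketch of the simplified signature (the import path shown is the one fairscale commonly exposes and may differ between versions):

```python
import torch.nn as nn
from fairscale.nn import checkpoint_wrapper  # import path assumed; adjust to your version

block = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
# Only `offload_to_cpu` remains; no maintain_forward_counter bookkeeping.
block = checkpoint_wrapper(block, offload_to_cpu=False)
```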
@@ -48,28 +48,3 @@ def patch_batchnorm(module: nn.Module) -> List:
         post_handle = child.register_forward_hook(post_forward)
         hooks += [pre_handle, post_handle]
     return hooks
-
-
-def init_counter(module: nn.Module) -> None:
-    """Add a checkpoint forward pass counter to a module and all its child FSDP modules.
-
-    ``inc_counter`` and ``dec_counter`` are used together with this to maintain counters
-    for FSDP to use in case of multiple forward pass and checkpoint being used at the same time.
-    """
-    for mod in module.modules():
-        mod._checkpoint_fwd_counter = 0
-
-
-def _add_counter(module: nn.Module, value: int) -> None:
-    if not hasattr(module, "_checkpoint_fwd_counter"):
-        return
-    for mod in module.modules():
-        mod._checkpoint_fwd_counter += value
-
-
-def inc_counter(module: nn.Module) -> None:
-    _add_counter(module, 1)
-
-
-def dec_counter(module: nn.Module) -> None:
-    _add_counter(module, -1)
@@ -380,7 +380,7 @@ class FullyShardedDataParallel(nn.Module):
                 f"FSDP.__init__(done): total_init_time: {(init_end - init_start): .4f} num_params: {(sum(p.numel() for p in self.params))}"
             )
 
-        # Flag to guard multiple pre-backward hook being executed per iteration.
+        # Flag to guard against preparing gradients multiple times per iteration.
         # This is reset at the end of the backward pass.
         self._pre_backward_hook_has_run = False
@@ -1050,11 +1050,6 @@
         # For children instances, if they are checkpointed, state will not be reset to
         # IDLE after each inner forward/backward.
         self.assert_state(TrainingState.IDLE)
-        # Check if the root instance is being checkpointed. It doesn't make sense to
-        # checkpoint the root instance since it won't save GPU memory.
-        assert (
-            getattr(self, "_checkpoint_fwd_counter", 0) == 0
-        ), "Is the root FSDP module wrapping an activation checkpointed module? If so, please remove that."
 
         # As the root, we now set all children instances to False and
         # give them a closure to try to queue a wait_for_post_backward.
         self.children_share_process_group = True
@@ -1217,16 +1212,19 @@
             # Only run the ``self._prep_grads_for_backward`` once per iteration (i.e. in case
             # it is multiple outputs or multiple forward passes).
-            if self._pre_backward_hook_has_run:
-                return
-            self._pre_backward_hook_has_run = True
-
-            # Start of a backward pass for the first time in an iteration.
-            self.assert_state([TrainingState.IDLE, TrainingState.BACKWARD_PRE])
-            self.training_state = TrainingState.BACKWARD_PRE
-
-            # Prepare p.grad so that it is in the right shape, device, accumulated values, etc.
-            self._prep_grads_for_backward()
+            if not self._pre_backward_hook_has_run:
+                self._pre_backward_hook_has_run = True
+                # Start of a backward pass for the first time in an iteration.
+                self.assert_state([TrainingState.IDLE, TrainingState.BACKWARD_PRE])
+                # Prepare p.grad so that it is in the right shape, device, accumulated values, etc.
+                self._prep_grads_for_backward()
+
+            # Transition to BACKWARD_PRE state if currently IDLE. We can transition from BACKWARD_POST
+            # to IDLE when FSDP is within activation checkpointing and called multiple times, due to the
+            # extra forward pass for re-computation.
+            if self.training_state == TrainingState.IDLE:
+                self.training_state = TrainingState.BACKWARD_PRE
+            self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
 
         def _register_hook(t: torch.Tensor) -> torch.Tensor:
             if t.requires_grad:
@@ -1309,14 +1307,8 @@
         the local optimizer only sees the relevant parameter shard.
         """
         # First hook callback will see PRE state. If we have multiple params,
-        # then subsequent hook callbacks will see POST state. When checkpoint
-        # fwd counter is used, IDLE is also possible since the pre-backward hook
-        # is not triggered (see ``auto_wrap_bn`` below, we have to use
-        # FSDP(checkpoint(conv, FSDP(bn), ...)), with reshard_after_forward=False).
-        if hasattr(self, "_checkpoint_fwd_counter"):
-            self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST, TrainingState.IDLE])
-        else:
-            self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
+        # then subsequent hook callbacks will see POST state.
+        self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
         self.training_state = TrainingState.BACKWARD_POST
         if param.grad is None:
             return
@@ -1324,13 +1316,6 @@
         if param.grad.requires_grad:
             raise RuntimeError("FSDP only works with gradients that don't require gradients")
 
-        # If this is a checkpointed module, we check if the following
-        # counter reaches 0. If not, it is not the final backward call
-        # for this module yet. Therefore, we early return in that case.
-        if hasattr(self._fsdp_wrapped_module, "_checkpoint_fwd_counter"):
-            if self._fsdp_wrapped_module._checkpoint_fwd_counter != 0:
-                return
-
         if self._require_backward_grad_sync or self.reshard_after_forward:
             # Free full params. As a special case, we don't free the full params
             # when in a ``no_sync`` context (as inversely indicated by
@@ -1365,18 +1350,34 @@
             # Average grad by world_size for consistency with PyTorch DDP.
             param.grad.data.div_(self.gradient_predivide_factor)
 
-        callback_fn = functools.partial(self._post_reduction_hook, param)
         if param._is_sharded:
-            assert param._is_sharded
             assert self._reducer is not None
-            grad_chunks = chunk_and_pad(param.grad.data, self.world_size)
+            # Save the unsharded grad for reduction. We will asynchronously accumulate the reduced gradient into
+            # param._saved_grad_shard. If this FSDP module was called multiple times it's possible that multiple
+            # gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
+            # matter, neglecting rounding.
+            grad = param.grad.data
+            # Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
+            #
+            # The effect on memory consumption is not usually significant. No extra memory is allocated if this
+            # module is called only once, reduction happens quickly, or the tensor is bucketed. If the module is
+            # called multiple times, and the backwards pass runs far enough ahead of the `post_backward` stream,
+            # then we can end up with multiple unsharded gradients allocated and queued for reduction.
+            #
+            # We could guard against this by using CUDA events (see record_event, wait_event in torch.cuda.Stream).
+            # This ensures the `default` stream will wait for the `post_backward` stream to complete the last
+            # reduction for this module, before scheduling additional reduction work. Then at most there are two
+            # unsharded gradients allocated; one for a pending reduction, and one for gradient computation.
+            param.grad = None
+            callback_fn = functools.partial(self._post_reduction_hook, param)
+            grad_chunks = chunk_and_pad(grad, self.world_size)
             self._reducer.reduce_scatter_async(grad_chunks, group=self.process_group, callback_fn=callback_fn)
         else:
             # Currently the only way for _is_sharded to be False is if
             # world_size == 1. This could be relaxed in the future, in which
             # case grads should be all-reduced here.
             assert self.world_size == 1
-            callback_fn(param.grad.data)
+            self._post_reduction_hook(param, param.grad.data)
 
         # After _post_backward_hook returns, orig_grad_data will eventually
         # go out of scope, at which point it could otherwise be freed for
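The CUDA-event guard mentioned in the comment above is only a suggestion, not part of this commit. A hedged sketch of what it could look like, using the standard `torch.cuda.Event`/`Stream.wait_event` API; the stream names mirror the diff and the function name is hypothetical:

```python
import torch

def wait_for_last_reduction(post_backward: torch.cuda.Stream) -> None:
    # Record an event at the current tail of the post_backward stream ...
    event = torch.cuda.Event()
    event.record(post_backward)
    # ... and make the default stream wait for it before queueing more gradient
    # work for this module, bounding the number of live unsharded gradients.
    torch.cuda.current_stream().wait_event(event)
```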
@@ -1388,33 +1389,36 @@
     def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
         """Hook to call on each param after the reduce-scatter."""
         assert torch.cuda.current_stream() == self._streams["post_backward"]
-        assert param.grad is not None
         self.assert_state(TrainingState.BACKWARD_POST)
-        param.grad.data = reduced_grad
         if self.gradient_postdivide_factor > 1:
             # Average grad by world_size for consistency with PyTorch DDP.
-            param.grad.data.div_(self.gradient_postdivide_factor)
+            reduced_grad.data.div_(self.gradient_postdivide_factor)
         # Cast grad to param's dtype (typically FP32). Note: we do this
         # before the move_grads_to_cpu step so that this entire hook remains
         # non-blocking. The downside is a bit more D2H transfer in that case.
         if self.mixed_precision:
-            orig_param_grad_data = param.grad.data
-            param.grad.data = param.grad.data.to(dtype=param.data.dtype)
+            orig_param_grad_data = reduced_grad.data
+            reduced_grad.data = reduced_grad.data.to(dtype=param.data.dtype)
             # Don't let this memory get reused until after the transfer.
             orig_param_grad_data.record_stream(torch.cuda.current_stream())
-        if hasattr(param, "_saved_grad_shard") and param._saved_grad_shard is not None:
-            assert (
-                param._saved_grad_shard.shape == param.grad.shape
-            ), f"{param._saved_grad_shard.shape} vs {param.grad.shape}"
-            param.grad.data += param._saved_grad_shard
-            delattr(param, "_saved_grad_shard")
-        # Optionally move gradients to CPU, typically used if one is running
-        # the optimizer on the CPU.
+
+        if param._is_sharded:
+            # Accumulate into the gradient shard.
+            if getattr(param, "_saved_grad_shard", None) is None:
+                param._saved_grad_shard = reduced_grad.data
+            else:
+                assert (
+                    param._saved_grad_shard.shape == reduced_grad.shape
+                ), f"{param._saved_grad_shard.shape} vs {reduced_grad.shape}"
+                param._saved_grad_shard.data += reduced_grad.data
+            reduced_grad = param._saved_grad_shard.data
+
+        # Optionally move gradients to CPU, typically used if one is running the optimizer on the CPU. Once the full
+        # backwards pass completes, we will set `.grad` to the CPU copy.
         if self.move_grads_to_cpu:
-            param._cpu_grad.copy_(param.grad.data, non_blocking=True)
+            param._cpu_grad.copy_(reduced_grad.data, non_blocking=True)
             # Don't let this memory get reused until after the transfer.
-            param.grad.data.record_stream(torch.cuda.current_stream())
-            param.grad.data = param._cpu_grad
+            reduced_grad.data.record_stream(torch.cuda.current_stream())
 
     def _queue_wait_for_post_backward(self) -> None:
         """Try to queue a `wait_for_post_backward` callback.
@@ -1457,19 +1461,38 @@
         if self._reducer is not None:
             self._reducer.teardown()
 
-        def _remove_shard_bwd_hook(fsdp_module: FullyShardedDataParallel) -> None:
+        def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
             """Helper used below on all fsdp modules."""
             for p in fsdp_module.params:
-                if p.requires_grad:
-                    if hasattr(p, "_shard_bwd_hook"):
-                        assert len(p._shard_bwd_hook) == 2, len(p._shard_bwd_hook)
-                        p._shard_bwd_hook[1].remove()
-                        delattr(p, "_shard_bwd_hook")
+                if not p.requires_grad:
+                    continue
+                if hasattr(p, "_shard_bwd_hook"):
+                    assert len(p._shard_bwd_hook) == 2, len(p._shard_bwd_hook)
+                    p._shard_bwd_hook[1].remove()
+                    delattr(p, "_shard_bwd_hook")
+
+                # Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
+                # remains the unsharded gradient accumulated from prior no-sync passes, and p._saved_grad_shard
+                # remains the sharded gradient from the last synchronized pass. This also allows interleaved no-sync and
+                # sync passes, if desired.
+                if not self._require_backward_grad_sync:
+                    continue
+
+                # Parameter and gradient devices must match.
+                if hasattr(p, "_cpu_grad"):
+                    assert p.device == torch.device("cpu")
+                    p.grad = p._cpu_grad
+                elif hasattr(p, "_saved_grad_shard"):
+                    assert p.device == p._saved_grad_shard.device
+                    p.grad = p._saved_grad_shard
+
+                if hasattr(p, "_saved_grad_shard"):
+                    delattr(p, "_saved_grad_shard")
 
         # Update root and nested FSDP's hooks and flags.
         for m in self.modules():  # includes self
             if isinstance(m, FullyShardedDataParallel):
-                _remove_shard_bwd_hook(m)
+                _finalize_parameters(m)
                 m._pre_backward_hook_has_run = False
                 if any(p.requires_grad for p in m.parameters()):
                     # Check if the module has params and if any of them has
......
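A hedged usage sketch of the no-sync behaviour described in `_finalize_parameters` above: inside `model.no_sync()` the unsharded gradients simply accumulate in `p.grad`, and the next synchronized backward pass reduce-scatters and publishes the sharded gradient. Here `model` is assumed to be a fairscale FSDP instance; the helper name is hypothetical.

```python
import torch

def accumulate_then_sync(model, optimizer, micro_batches) -> None:
    *head, last = micro_batches
    with model.no_sync():
        for batch in head:
            model(batch).sum().backward()  # local accumulation, no reduction
    model(last).sum().backward()           # gradients reduced and sharded here
    optimizer.step()
    optimizer.zero_grad()
```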
@@ -106,9 +106,6 @@ class Module(Generic[T_co]):
     def extra_repr(self) -> str: ...
 
-    # This is added by checkpoint_wrapper
-    _checkpoint_fwd_counter: int
-
     # This is added torchgpipe
     training: bool
......
@@ -106,7 +106,7 @@ def parity3d_checkpoint_syncbn():
     x = torch.randn(4, 3, 4, 4, 4).cuda() * rank
     torch_bn = torch.nn.SyncBatchNorm(3).cuda()
     fs_bn = SyncBatchNorm(3).cuda()
-    fs_bn = checkpoint_wrapper(fs_bn, maintain_forward_counter=True)
+    fs_bn = checkpoint_wrapper(fs_bn)
     check_parity_ddp(torch_bn, fs_bn, x)
......
@@ -70,7 +70,15 @@ class Model2(nn.Module):
 def _create_model(
-    with_model2, with_sync_bn, with_fsdp, with_checkpoint, mixed_precision, flatten, wrap_bn, fp32_reduce_scatter
+    with_model2,
+    with_sync_bn,
+    with_fsdp,
+    with_checkpoint,
+    mixed_precision,
+    flatten,
+    wrap_bn,
+    fp32_reduce_scatter,
+    bucket_cap_mb,
 ):
     model = Model2() if with_model2 else Model()
     fsdp_config = None
@@ -84,33 +92,32 @@ def _create_model(
             "force_input_to_fp32": True,  # SyncBN needs this.
         }
 
+    if with_fsdp and wrap_bn:
+        model.block1 = auto_wrap_bn(model.block1, single_rank_pg=False, fsdp_config=fsdp_config)
+        model.block2 = auto_wrap_bn(model.block2, single_rank_pg=False, fsdp_config=fsdp_config)
+        if with_model2:
+            model.block3 = auto_wrap_bn(model.block3, single_rank_pg=False, fsdp_config=fsdp_config)
+
+    if with_checkpoint:
+        model.block2 = checkpoint_wrapper(model.block2)
+        if with_model2:
+            model.block3 = checkpoint_wrapper(model.block3)
+
     if with_fsdp:
-        if wrap_bn:
-            model.block1 = auto_wrap_bn(model.block1, single_rank_pg=False, fsdp_config=fsdp_config)
-            model.block2 = auto_wrap_bn(model.block2, single_rank_pg=False, fsdp_config=fsdp_config)
-            if with_model2:
-                model.block3 = auto_wrap_bn(model.block3, single_rank_pg=False, fsdp_config=fsdp_config)
-        if with_checkpoint:
-            model.block2 = checkpoint_wrapper(model.block2, maintain_forward_counter=True)
-            if with_model2:
-                model.block3 = checkpoint_wrapper(model.block3, maintain_forward_counter=True)
         with enable_wrap(
             wrapper_cls=FSDP,
             flatten_parameters=flatten,
             mixed_precision=mixed_precision,
             compute_dtype=torch.float32,
             fp32_reduce_scatter=fp32_reduce_scatter,
+            bucket_cap_mb=bucket_cap_mb,
         ):
             model.block1 = wrap(model.block1)
             model.block2 = wrap(model.block2)
             if with_model2:
                 model.block3 = wrap(model.block3)
             model.head = wrap(model.head)
-    else:
-        if with_checkpoint:
-            model.block2 = checkpoint_wrapper(model.block2, maintain_forward_counter=False)
-            if with_model2:
-                model.block3 = checkpoint_wrapper(model.block3, maintain_forward_counter=False)
 
     return model
@@ -126,6 +133,7 @@ def _distributed_worker(
     flatten,
     wrap_bn,
     fp32_reduce_scatter,
+    bucket_cap_mb,
 ):
     filename, filename_rpc = files[:2]
     filename_loss = files[2:]
@@ -155,7 +163,15 @@
         batch = [x.half() for x in batch]
 
     model = _create_model(
-        with_model2, with_sync_bn, with_fsdp, with_checkpoint, mixed_precision, flatten, wrap_bn, fp32_reduce_scatter
+        with_model2,
+        with_sync_bn,
+        with_fsdp,
+        with_checkpoint,
+        mixed_precision,
+        flatten,
+        wrap_bn,
+        fp32_reduce_scatter,
+        bucket_cap_mb,
     )
     model = model.cuda()
@@ -166,6 +182,7 @@
             mixed_precision=mixed_precision,
             compute_dtype=torch.float32,
             fp32_reduce_scatter=fp32_reduce_scatter,
+            bucket_cap_mb=bucket_cap_mb,
         )
         model.set_gradient_divide_factors(1.0, 2.0, True)
     no_sync_context = contextlib.suppress()
@@ -228,6 +245,7 @@ def _get_cached_results(
     flatten,
     wrap_bn,
     fp32_reduce_scatter,
+    bucket_cap_mb,
 ):
     """ Cache the training to save time. For DDP, flatten, wrap_bn etc. doesn't matter, so
         the results can be cached.
@@ -247,6 +265,7 @@
         flatten,
         wrap_bn,
         fp32_reduce_scatter,
+        bucket_cap_mb,
     )
     global _result_cache
     if key not in _result_cache:
@@ -265,6 +284,7 @@
                 flatten,
                 wrap_bn,
                 fp32_reduce_scatter,
+                bucket_cap_mb,
             ),
             nprocs=world_size,
         )
@@ -310,39 +330,66 @@ def test_multiple_forward_checkpoint(precision, flatten, wrap_bn, model_type, bn
     world_size = 2
     expected_losses = None
-    # Ensure ddp == ddp+ckpt == fsdp == fsdp+ckpt.
+
+    # Ensure ddp == fsdp when modules are called multiple times per forward pass with/without checkpointing, forward
+    # counters and reducer bucketing.
+    #
+    # The bucketing check exists because the asynchronous gradient reduction it induces can interact with multiple
+    # forward passes in complex ways. For example, in the midst of a sharded backward pass, `parameter.grad` may only be
+    # `None` or an unsharded gradient tensor. The sharded tensor is then set at the end of the backwards pass. But a
+    # unit test with bucketing enabled might not catch violations of this invariant. For very small models, like the
+    # kind used in this unit test, bucketing will delay gradient reduction until after all the gradient computation is
+    # done. If the reduction incorrectly sets `.grad` to the _sharded_ variant, the test might not fail, since the
+    # gradient computations have already happened. Toggling bucketing helps verify that gradient reduction and
+    # computation interact correctly.
+    combinations = []
     for with_fsdp in [False, True]:
         for with_checkpoint in [False, True]:
             if not with_fsdp and with_checkpoint:
                 continue
-            final_losses = _get_cached_results(
-                world_size,
-                with_model2,
-                with_sync_bn,
-                with_fsdp,
-                with_checkpoint,
-                mixed_precision,
-                flatten,
-                wrap_bn,
-                fp32_reduce_scatter,
-            )
-            if expected_losses is None:
-                expected_losses = final_losses
-            else:
-                print(f"checking: fsdp {with_fsdp} ckpt {with_checkpoint} with ddp+no_ckpt")
-
-                def check(exp, res):
-                    assert list(exp.keys()) == list(res.keys()), f"{list(exp.keys())} vs. {list(res.keys())}"
-                    rtol = 1e-4
-                    atol = 1e-5
-                    if with_model2 and mixed_precision and torch_version() >= (1, 9, 0):
-                        # On CI, with longer model2, mixed precsion and 1.9, even ddp vs. ddp+ckpt has
-                        # larger errors.
-                        rtol = 1e-3
-                        atol = 1e-4
-                    for key in exp.keys():
-                        exp_loss = exp[key]
-                        res_loss = res[key]
-                        torch.testing.assert_allclose(exp_loss, res_loss, rtol=rtol, atol=atol)
-
-                check(expected_losses, final_losses)
+            for with_bucketing in [False, True]:
+                if not with_fsdp and with_bucketing:
+                    continue
+                combinations.append((with_fsdp, with_checkpoint, with_bucketing))
+
+    print("")
+    print("Testing the following configurations:")
+    for with_fsdp, with_checkpoint, with_bucketing in combinations:
+        print(f"  fsdp {with_fsdp} ckpt {with_checkpoint} bucketing {with_bucketing}")
+
+    for with_fsdp, with_checkpoint, with_bucketing in combinations:
+        if with_bucketing:
+            bucket_cap_mb = 25
+        else:
+            bucket_cap_mb = 0
+        final_losses = _get_cached_results(
+            world_size,
+            with_model2,
+            with_sync_bn,
+            with_fsdp,
+            with_checkpoint,
+            mixed_precision,
+            flatten,
+            wrap_bn,
+            fp32_reduce_scatter,
+            bucket_cap_mb,
+        )
+        if expected_losses is None:
+            expected_losses = final_losses
+        else:
+            print(f"checking: fsdp {with_fsdp} ckpt {with_checkpoint} bucketing {with_bucketing} with ddp+no_ckpt")
+
+            def check(exp, res):
+                assert list(exp.keys()) == list(res.keys()), f"{list(exp.keys())} vs. {list(res.keys())}"
+                rtol = 1e-4
+                atol = 1e-5
+                if with_model2 and mixed_precision and torch_version() >= (1, 9, 0):
+                    # On CI, with longer model2, mixed precsion and 1.9, even ddp vs. ddp+ckpt has
+                    # larger errors.
+                    rtol = 1e-3
+                    atol = 1e-4
+                for key in exp.keys():
+                    exp_loss = exp[key]
+                    res_loss = res[key]
+                    torch.testing.assert_allclose(exp_loss, res_loss, rtol=rtol, atol=atol)
+
+            check(expected_losses, final_losses)
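The invariant described in the new test comment above can also be spot-checked directly. A hedged sketch of such a check (not part of this commit): while the backward pass is still running, an FSDP-managed parameter's `.grad` should be either `None` or unsharded (full size); the sharded gradient only appears once the backward pass has finished. `full_numel`, the unsharded parameter size, is an assumption captured before training.

```python
import torch

def assert_grad_not_sharded(p: torch.nn.Parameter, full_numel: int) -> None:
    if p.grad is not None:
        assert p.grad.numel() == full_numel, "p.grad was set to a sharded gradient mid-backward"
```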
@@ -71,12 +71,10 @@ def _train_step(model, optim, expected_param_shapes):
     # Create input and run forward pass.
     input = torch.randn(2, 3).cuda()
     loss = model(input).sum()
-    _check_fwd_counter(model, 1)
     _check_params(model, expected_param_shapes)
 
     # Run backward pass.
     loss.backward()
-    _check_fwd_counter(model, 0)
     _check_params(model, expected_param_shapes)
 
     # Finally, take a step.
@@ -90,7 +88,6 @@ def _eval_step(model, optim, expected_param_shapes):
     with torch.no_grad():
         input = torch.randn(2, 3).cuda()
         model(input).sum()
-    _check_fwd_counter(model, 0)
     _check_params(model, expected_param_shapes)
@@ -104,20 +101,11 @@
         ), f"Parameter {key} should have shape {expected_shape}, but found shape {current_shape}"
 
 
-def _check_fwd_counter(model, expected_value):
-    current_value = model._fpw_module.ffn[1]._fsdp_wrapped_module.module._checkpoint_fwd_counter
-    assert (
-        current_value == expected_value
-    ), f"forward counter of checkpointed submodule should be {expected_value}, but found {current_value}"
-
-
 class SimpleModuleWithCheckpointing(nn.Module):
     def __init__(self):
         super().__init__()
         self.ffn = nn.Sequential(
-            nn.Linear(3, 3),
-            FullyShardedDataParallel(checkpoint_wrapper(nn.Linear(3, 3), maintain_forward_counter=True)),
-            nn.Linear(3, 3),
+            nn.Linear(3, 3), FullyShardedDataParallel(checkpoint_wrapper(nn.Linear(3, 3))), nn.Linear(3, 3),
         )
 
     def forward(self, x):
......