Unverified Commit 4fa2ab9b authored by Darryl Barnhart's avatar Darryl Barnhart Committed by GitHub
Browse files

[fix] FSDP intra-backwards gradient accumulation. (#784)

* [fix] FSDP intra-backwards gradient accumulation.

Ensure gradient reduction accumulates into the unsharded gradient tensor
within a backwards pass. This matters when an FSDP module is called
multiple times within a forward pass, and reduction is _not_ deferred
using activation checkpoint forward counters, bucketing or some other
mechanism.

Closes #780

* [refactor] Remove forward counters. Comments.

Removed forward counters from the activation checkpointing utility, now
that FSDP does not require them for correct operation. Add more detailed
comment about memory usage behaviour with gradient reduction.

* [refactor] Delete deprecated forward counter usage.

* [refactor] Add state assertion as end of pre-backward hook.
parent 482944d9
......@@ -16,7 +16,7 @@ import torch.utils.checkpoint as torch_checkpoint
from fairscale.utils.containers import pack_kwargs, split_non_tensors, unpack_kwargs, unpack_non_tensors
from .checkpoint_utils import dec_counter, inc_counter, init_counter, patch_batchnorm
from .checkpoint_utils import patch_batchnorm
# https://docs.python.org/3/library/threading.html#thread-local-data
......@@ -95,9 +95,7 @@ def is_recomputing() -> bool:
return thread_local.is_recomputing
def checkpoint_wrapper(
module: nn.Module, offload_to_cpu: bool = False, maintain_forward_counter: bool = False
) -> nn.Module:
def checkpoint_wrapper(module: nn.Module, offload_to_cpu: bool = False,) -> nn.Module:
"""
A friendlier wrapper for performing activation checkpointing.
......@@ -139,10 +137,6 @@ def checkpoint_wrapper(
The module to be wrapped
offload_to_cpu (bool):
Whether to offload activations to CPU.
maintain_forward_counter (bool):
If True, maintain a forward counter per inner module. The counter will first
increases in forward calls of outer forward pass and then decreases in the
forward calls of outer backward pass. It is used by FullyShardedDataParallel.
Returns:
(nn.Module):
......@@ -151,9 +145,6 @@ def checkpoint_wrapper(
# Patch the batchnorm layers in case there are any in this module.
patch_batchnorm(module)
if maintain_forward_counter:
init_counter(module)
# The use of weakref here is to prevent creating a ref cycle: m -> m.forward -> m.
# When such cycle exists, gc won't collect the module when the module is freed.
# That causes GPU memory to be leaked. See the unit test for how we catch that.
......@@ -172,10 +163,6 @@ def _checkpointed_forward(
module = weak_self()
# If gradients are disabled, just use original `.forward()` method directly.
# Doing so also ensures the internal fwd counter is not incremented in the forward pass,
# which would be an issue during eval since there wouldn't be a corresponding backward pass
# to decrement the fwd counter.
# See https://github.com/facebookresearch/fairscale/pull/709.
if not torch.is_grad_enabled() or thread_local.is_checkpointing_disabled:
return original_forward(module, *args, **kwargs)
......@@ -290,8 +277,6 @@ class CheckpointFunction(torch.autograd.Function):
with torch.no_grad(), enable_checkpointing():
unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)
outputs = run_function(*unpacked_args, **unpacked_kwargs)
the_module = unpacked_args[0]
inc_counter(the_module)
# Because we run with torch.no_grad(), we can't actually access
# outputs.requires_grad. Instead, we manually compute it by
......@@ -341,8 +326,6 @@ class CheckpointFunction(torch.autograd.Function):
unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs)
outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs)
tensor_outputs, _ = split_non_tensors(outputs)
the_module = unpacked_args[0]
dec_counter(the_module)
# Set the states back to what it was at the start of this function.
set_rng_state(bwd_rng_state)
......
......@@ -48,28 +48,3 @@ def patch_batchnorm(module: nn.Module) -> List:
post_handle = child.register_forward_hook(post_forward)
hooks += [pre_handle, post_handle]
return hooks
def init_counter(module: nn.Module) -> None:
"""Add a checkpoint forward pass counter to a module and all its child FSDP modules.
``inc_counter`` and ``dec_counter`` are used together with this to maintain counters
for FSDP to use in case of multiple forward pass and checkpoint being used at the same time.
"""
for mod in module.modules():
mod._checkpoint_fwd_counter = 0
def _add_counter(module: nn.Module, value: int) -> None:
if not hasattr(module, "_checkpoint_fwd_counter"):
return
for mod in module.modules():
mod._checkpoint_fwd_counter += value
def inc_counter(module: nn.Module) -> None:
_add_counter(module, 1)
def dec_counter(module: nn.Module) -> None:
_add_counter(module, -1)
......@@ -380,7 +380,7 @@ class FullyShardedDataParallel(nn.Module):
f"FSDP.__init__(done): total_init_time: {(init_end - init_start): .4f} num_params: {(sum(p.numel() for p in self.params))}"
)
# Flag to guard multiple pre-backward hook being executed per iteration.
# Flag to guard against preparing gradients multiple times per iteration.
# This is reset at the end of the backward pass.
self._pre_backward_hook_has_run = False
......@@ -1050,11 +1050,6 @@ class FullyShardedDataParallel(nn.Module):
# For children instances, if they are checkpointed, state will not be reset to
# IDLE after each inner forward/backward.
self.assert_state(TrainingState.IDLE)
# Check if the root instance is being checkpointed. It doesn't make sense to
# checkpoint the root instance since it won't save GPU memory.
assert (
getattr(self, "_checkpoint_fwd_counter", 0) == 0
), "Is the root FSDP module wrapping an activation checkpointed module? If so, please remove that."
# As the root, we now set all children instances to False and
# give them a closure to try to queue a wait_for_post_backward.
self.children_share_process_group = True
......@@ -1217,16 +1212,19 @@ class FullyShardedDataParallel(nn.Module):
# Only run the ``self._prep_grads_for_backward`` once per iteration (i.e. in case
# it is multiple outputs or multiple forward passes).
if self._pre_backward_hook_has_run:
return
self._pre_backward_hook_has_run = True
# Start of a backward pass for the first time in an iteration.
self.assert_state([TrainingState.IDLE, TrainingState.BACKWARD_PRE])
self.training_state = TrainingState.BACKWARD_PRE
# Prepare p.grad so that it is in the right shape, device, accumulated values, etc.
self._prep_grads_for_backward()
if not self._pre_backward_hook_has_run:
self._pre_backward_hook_has_run = True
# Start of a backward pass for the first time in an iteration.
self.assert_state([TrainingState.IDLE, TrainingState.BACKWARD_PRE])
# Prepare p.grad so that it is in the right shape, device, accumulated values, etc.
self._prep_grads_for_backward()
# Transition to BACKWARD_PRE state if currently IDLE. We can transition from BACKWARD_POST
# to IDLE when FSDP is within activation checkpointing and called multiple times, due to the
# extra forward pass for re-computation.
if self.training_state == TrainingState.IDLE:
self.training_state = TrainingState.BACKWARD_PRE
self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
def _register_hook(t: torch.Tensor) -> torch.Tensor:
if t.requires_grad:
......@@ -1309,14 +1307,8 @@ class FullyShardedDataParallel(nn.Module):
the local optimizer only sees the relevant parameter shard.
"""
# First hook callback will see PRE state. If we have multiple params,
# then subsequent hook callbacks will see POST state. When checkpoint
# fwd counter is used, IDLE is also possible since the pre-backward hook
# is not triggered (see ``auto_wrap_bn`` below, we have to use
# FSDP(checkpoint(conv, FSDP(bn), ...)), with reshard_after_forward=False).
if hasattr(self, "_checkpoint_fwd_counter"):
self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST, TrainingState.IDLE])
else:
self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
# then subsequent hook callbacks will see POST state.
self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
self.training_state = TrainingState.BACKWARD_POST
if param.grad is None:
return
......@@ -1324,13 +1316,6 @@ class FullyShardedDataParallel(nn.Module):
if param.grad.requires_grad:
raise RuntimeError("FSDP only works with gradients that don't require gradients")
# If this is a checkpointed module, we check if the following
# counter reaches 0. If not, it is not the final backward call
# for this module yet. Therefore, we early return in that case.
if hasattr(self._fsdp_wrapped_module, "_checkpoint_fwd_counter"):
if self._fsdp_wrapped_module._checkpoint_fwd_counter != 0:
return
if self._require_backward_grad_sync or self.reshard_after_forward:
# Free full params. As a special case, we don't free the full params
# when in a ``no_sync`` context (as inversely indicated by
......@@ -1365,18 +1350,34 @@ class FullyShardedDataParallel(nn.Module):
# Average grad by world_size for consistency with PyTorch DDP.
param.grad.data.div_(self.gradient_predivide_factor)
callback_fn = functools.partial(self._post_reduction_hook, param)
if param._is_sharded:
assert param._is_sharded
assert self._reducer is not None
grad_chunks = chunk_and_pad(param.grad.data, self.world_size)
# Save the unsharded grad for reduction. We will asynchronously accumulate the reduced gradient into
# param._saved_grad_shard. If this FSDP module was called multiple times it's possible that multiple
# gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
# matter, neglecting rounding.
grad = param.grad.data
# Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
#
# The effect on memory consumption is not usually significant. No extra memory is allocated if this
# module is called only once, reduction happens quickly, or the tensor is bucketed. If the module is
# called multiple times, and the backwards pass runs far enough ahead of the `post_backward` stream,
# then we can end up with multiple unsharded gradients allocated and queued for reduction.
#
# We could guard against this by using CUDA events (see record_event, wait_event in torch.cuda.Stream).
# This ensures the `default` stream will wait for the `post_backward` stream to complete the last
# reduction for this module, before scheduling additional reduction work. Then at most there are two
# unsharded gradients allocated; one for a pending reduction, and one for gradient computation.
param.grad = None
callback_fn = functools.partial(self._post_reduction_hook, param)
grad_chunks = chunk_and_pad(grad, self.world_size)
self._reducer.reduce_scatter_async(grad_chunks, group=self.process_group, callback_fn=callback_fn)
else:
# Currently the only way for _is_sharded to be False is if
# world_size == 1. This could be relaxed in the future, in which
# case grads should be all-reduced here.
assert self.world_size == 1
callback_fn(param.grad.data)
self._post_reduction_hook(param, param.grad.data)
# After _post_backward_hook returns, orig_grad_data will eventually
# go out of scope, at which point it could otherwise be freed for
......@@ -1388,33 +1389,36 @@ class FullyShardedDataParallel(nn.Module):
def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
"""Hook to call on each param after the reduce-scatter."""
assert torch.cuda.current_stream() == self._streams["post_backward"]
assert param.grad is not None
self.assert_state(TrainingState.BACKWARD_POST)
param.grad.data = reduced_grad
if self.gradient_postdivide_factor > 1:
# Average grad by world_size for consistency with PyTorch DDP.
param.grad.data.div_(self.gradient_postdivide_factor)
reduced_grad.data.div_(self.gradient_postdivide_factor)
# Cast grad to param's dtype (typically FP32). Note: we do this
# before the move_grads_to_cpu step so that this entire hook remains
# non-blocking. The downside is a bit more D2H transfer in that case.
if self.mixed_precision:
orig_param_grad_data = param.grad.data
param.grad.data = param.grad.data.to(dtype=param.data.dtype)
orig_param_grad_data = reduced_grad.data
reduced_grad.data = reduced_grad.data.to(dtype=param.data.dtype)
# Don't let this memory get reused until after the transfer.
orig_param_grad_data.record_stream(torch.cuda.current_stream())
if hasattr(param, "_saved_grad_shard") and param._saved_grad_shard is not None:
assert (
param._saved_grad_shard.shape == param.grad.shape
), f"{param._saved_grad_shard.shape} vs {param.grad.shape}"
param.grad.data += param._saved_grad_shard
delattr(param, "_saved_grad_shard")
# Optionally move gradients to CPU, typically used if one is running
# the optimizer on the CPU.
if param._is_sharded:
# Accumulate into the gradient shard.
if getattr(param, "_saved_grad_shard", None) is None:
param._saved_grad_shard = reduced_grad.data
else:
assert (
param._saved_grad_shard.shape == reduced_grad.shape
), f"{param._saved_grad_shard.shape} vs {reduced_grad.shape}"
param._saved_grad_shard.data += reduced_grad.data
reduced_grad = param._saved_grad_shard.data
# Optionally move gradients to CPU, typically used if one is running the optimizer on the CPU. Once the full
# backwards pass completes, we will set `.grad` to the CPU copy.
if self.move_grads_to_cpu:
param._cpu_grad.copy_(param.grad.data, non_blocking=True)
param._cpu_grad.copy_(reduced_grad.data, non_blocking=True)
# Don't let this memory get reused until after the transfer.
param.grad.data.record_stream(torch.cuda.current_stream())
param.grad.data = param._cpu_grad
reduced_grad.data.record_stream(torch.cuda.current_stream())
def _queue_wait_for_post_backward(self) -> None:
"""Try to queue a `wait_for_post_backward` callback.
......@@ -1457,19 +1461,38 @@ class FullyShardedDataParallel(nn.Module):
if self._reducer is not None:
self._reducer.teardown()
def _remove_shard_bwd_hook(fsdp_module: FullyShardedDataParallel) -> None:
def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
"""Helper used below on all fsdp modules."""
for p in fsdp_module.params:
if p.requires_grad:
if hasattr(p, "_shard_bwd_hook"):
assert len(p._shard_bwd_hook) == 2, len(p._shard_bwd_hook)
p._shard_bwd_hook[1].remove()
delattr(p, "_shard_bwd_hook")
if not p.requires_grad:
continue
if hasattr(p, "_shard_bwd_hook"):
assert len(p._shard_bwd_hook) == 2, len(p._shard_bwd_hook)
p._shard_bwd_hook[1].remove()
delattr(p, "_shard_bwd_hook")
# Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
# remains the unsharded gradient accumulated from prior no-sync passes, and p._saved_grad_shard
# remains the sharded gradient from the last synchronized pass. This also allows interleaved no-sync and
# sync passes, if desired.
if not self._require_backward_grad_sync:
continue
# Parameter and gradient devices must match.
if hasattr(p, "_cpu_grad"):
assert p.device == torch.device("cpu")
p.grad = p._cpu_grad
elif hasattr(p, "_saved_grad_shard"):
assert p.device == p._saved_grad_shard.device
p.grad = p._saved_grad_shard
if hasattr(p, "_saved_grad_shard"):
delattr(p, "_saved_grad_shard")
# Update root and nested FSDP's hooks and flags.
for m in self.modules(): # includes self
if isinstance(m, FullyShardedDataParallel):
_remove_shard_bwd_hook(m)
_finalize_parameters(m)
m._pre_backward_hook_has_run = False
if any(p.requires_grad for p in m.parameters()):
# Check if the module has params and if any of them has
......
......@@ -106,9 +106,6 @@ class Module(Generic[T_co]):
def extra_repr(self) -> str: ...
# This is added by checkpoint_wrapper
_checkpoint_fwd_counter: int
# This is added torchgpipe
training: bool
......
......@@ -106,7 +106,7 @@ def parity3d_checkpoint_syncbn():
x = torch.randn(4, 3, 4, 4, 4).cuda() * rank
torch_bn = torch.nn.SyncBatchNorm(3).cuda()
fs_bn = SyncBatchNorm(3).cuda()
fs_bn = checkpoint_wrapper(fs_bn, maintain_forward_counter=True)
fs_bn = checkpoint_wrapper(fs_bn)
check_parity_ddp(torch_bn, fs_bn, x)
......
......@@ -70,7 +70,15 @@ class Model2(nn.Module):
def _create_model(
with_model2, with_sync_bn, with_fsdp, with_checkpoint, mixed_precision, flatten, wrap_bn, fp32_reduce_scatter
with_model2,
with_sync_bn,
with_fsdp,
with_checkpoint,
mixed_precision,
flatten,
wrap_bn,
fp32_reduce_scatter,
bucket_cap_mb,
):
model = Model2() if with_model2 else Model()
fsdp_config = None
......@@ -84,33 +92,32 @@ def _create_model(
"force_input_to_fp32": True, # SyncBN needs this.
}
if with_fsdp and wrap_bn:
model.block1 = auto_wrap_bn(model.block1, single_rank_pg=False, fsdp_config=fsdp_config)
model.block2 = auto_wrap_bn(model.block2, single_rank_pg=False, fsdp_config=fsdp_config)
if with_model2:
model.block3 = auto_wrap_bn(model.block3, single_rank_pg=False, fsdp_config=fsdp_config)
if with_checkpoint:
model.block2 = checkpoint_wrapper(model.block2)
if with_model2:
model.block3 = checkpoint_wrapper(model.block3)
if with_fsdp:
if wrap_bn:
model.block1 = auto_wrap_bn(model.block1, single_rank_pg=False, fsdp_config=fsdp_config)
model.block2 = auto_wrap_bn(model.block2, single_rank_pg=False, fsdp_config=fsdp_config)
if with_model2:
model.block3 = auto_wrap_bn(model.block3, single_rank_pg=False, fsdp_config=fsdp_config)
if with_checkpoint:
model.block2 = checkpoint_wrapper(model.block2, maintain_forward_counter=True)
if with_model2:
model.block3 = checkpoint_wrapper(model.block3, maintain_forward_counter=True)
with enable_wrap(
wrapper_cls=FSDP,
flatten_parameters=flatten,
mixed_precision=mixed_precision,
compute_dtype=torch.float32,
fp32_reduce_scatter=fp32_reduce_scatter,
bucket_cap_mb=bucket_cap_mb,
):
model.block1 = wrap(model.block1)
model.block2 = wrap(model.block2)
if with_model2:
model.block3 = wrap(model.block3)
model.head = wrap(model.head)
else:
if with_checkpoint:
model.block2 = checkpoint_wrapper(model.block2, maintain_forward_counter=False)
if with_model2:
model.block3 = checkpoint_wrapper(model.block3, maintain_forward_counter=False)
return model
......@@ -126,6 +133,7 @@ def _distributed_worker(
flatten,
wrap_bn,
fp32_reduce_scatter,
bucket_cap_mb,
):
filename, filename_rpc = files[:2]
filename_loss = files[2:]
......@@ -155,7 +163,15 @@ def _distributed_worker(
batch = [x.half() for x in batch]
model = _create_model(
with_model2, with_sync_bn, with_fsdp, with_checkpoint, mixed_precision, flatten, wrap_bn, fp32_reduce_scatter
with_model2,
with_sync_bn,
with_fsdp,
with_checkpoint,
mixed_precision,
flatten,
wrap_bn,
fp32_reduce_scatter,
bucket_cap_mb,
)
model = model.cuda()
......@@ -166,6 +182,7 @@ def _distributed_worker(
mixed_precision=mixed_precision,
compute_dtype=torch.float32,
fp32_reduce_scatter=fp32_reduce_scatter,
bucket_cap_mb=bucket_cap_mb,
)
model.set_gradient_divide_factors(1.0, 2.0, True)
no_sync_context = contextlib.suppress()
......@@ -228,6 +245,7 @@ def _get_cached_results(
flatten,
wrap_bn,
fp32_reduce_scatter,
bucket_cap_mb,
):
""" Cache the training to save time. For DDP, flatten, wrap_bn etc. doesn't matter, so
the results can be cached.
......@@ -247,6 +265,7 @@ def _get_cached_results(
flatten,
wrap_bn,
fp32_reduce_scatter,
bucket_cap_mb,
)
global _result_cache
if key not in _result_cache:
......@@ -265,6 +284,7 @@ def _get_cached_results(
flatten,
wrap_bn,
fp32_reduce_scatter,
bucket_cap_mb,
),
nprocs=world_size,
)
......@@ -310,39 +330,66 @@ def test_multiple_forward_checkpoint(precision, flatten, wrap_bn, model_type, bn
world_size = 2
expected_losses = None
# Ensure ddp == ddp+ckpt == fsdp == fsdp+ckpt.
# Ensure ddp == fsdp when modules are called multiple times per forward pass with/without checkpointing, forward
# counters and reducer bucketing.
#
# The bucketing check exists because the asynchronous gradient reduction it induces can interact with multiple
# forward passes in complex ways. For example, in the midst of a sharded backward pass, `parameter.grad` may only be
# `None` or an unsharded gradient tensor. The sharded tensor is then set at the end of the backwards pass. But a
# unit test with bucketing enabled might not catch violations of this invariant. For very small models, like the
# kind used in this unit test, bucketing will delay gradient reduction until after all the gradient computation is
# done. If the reduction incorrectly sets `.grad` to the _sharded_ variant, the test might not fail, since the
# gradient computations have already happened. Toggling bucketing helps verify that gradient reduction and
# computation interact correctly.
combinations = []
for with_fsdp in [False, True]:
for with_checkpoint in [False, True]:
if not with_fsdp and with_checkpoint:
continue
final_losses = _get_cached_results(
world_size,
with_model2,
with_sync_bn,
with_fsdp,
with_checkpoint,
mixed_precision,
flatten,
wrap_bn,
fp32_reduce_scatter,
)
if expected_losses is None:
expected_losses = final_losses
else:
print(f"checking: fsdp {with_fsdp} ckpt {with_checkpoint} with ddp+no_ckpt")
def check(exp, res):
assert list(exp.keys()) == list(res.keys()), f"{list(exp.keys())} vs. {list(res.keys())}"
rtol = 1e-4
atol = 1e-5
if with_model2 and mixed_precision and torch_version() >= (1, 9, 0):
# On CI, with longer model2, mixed precsion and 1.9, even ddp vs. ddp+ckpt has
# larger errors.
rtol = 1e-3
atol = 1e-4
for key in exp.keys():
exp_loss = exp[key]
res_loss = res[key]
torch.testing.assert_allclose(exp_loss, res_loss, rtol=rtol, atol=atol)
check(expected_losses, final_losses)
for with_bucketing in [False, True]:
if not with_fsdp and with_bucketing:
continue
combinations.append((with_fsdp, with_checkpoint, with_bucketing))
print("")
print("Testing the following configurations:")
for with_fsdp, with_checkpoint, with_bucketing in combinations:
print(f" fsdp {with_fsdp} ckpt {with_checkpoint} bucketing {with_bucketing}")
for with_fsdp, with_checkpoint, with_bucketing in combinations:
if with_bucketing:
bucket_cap_mb = 25
else:
bucket_cap_mb = 0
final_losses = _get_cached_results(
world_size,
with_model2,
with_sync_bn,
with_fsdp,
with_checkpoint,
mixed_precision,
flatten,
wrap_bn,
fp32_reduce_scatter,
bucket_cap_mb,
)
if expected_losses is None:
expected_losses = final_losses
else:
print(f"checking: fsdp {with_fsdp} ckpt {with_checkpoint} bucketing {with_bucketing} with ddp+no_ckpt")
def check(exp, res):
assert list(exp.keys()) == list(res.keys()), f"{list(exp.keys())} vs. {list(res.keys())}"
rtol = 1e-4
atol = 1e-5
if with_model2 and mixed_precision and torch_version() >= (1, 9, 0):
# On CI, with longer model2, mixed precsion and 1.9, even ddp vs. ddp+ckpt has
# larger errors.
rtol = 1e-3
atol = 1e-4
for key in exp.keys():
exp_loss = exp[key]
res_loss = res[key]
torch.testing.assert_allclose(exp_loss, res_loss, rtol=rtol, atol=atol)
check(expected_losses, final_losses)
......@@ -71,12 +71,10 @@ def _train_step(model, optim, expected_param_shapes):
# Create input and run forward pass.
input = torch.randn(2, 3).cuda()
loss = model(input).sum()
_check_fwd_counter(model, 1)
_check_params(model, expected_param_shapes)
# Run backward pass.
loss.backward()
_check_fwd_counter(model, 0)
_check_params(model, expected_param_shapes)
# Finally, take a step.
......@@ -90,7 +88,6 @@ def _eval_step(model, optim, expected_param_shapes):
with torch.no_grad():
input = torch.randn(2, 3).cuda()
model(input).sum()
_check_fwd_counter(model, 0)
_check_params(model, expected_param_shapes)
......@@ -104,20 +101,11 @@ def _check_params(model, expected_param_shapes):
), f"Parameter {key} should have shape {expected_shape}, but found shape {current_shape}"
def _check_fwd_counter(model, expected_value):
current_value = model._fpw_module.ffn[1]._fsdp_wrapped_module.module._checkpoint_fwd_counter
assert (
current_value == expected_value
), f"forward counter of checkpointed submodule should be {expected_value}, but found {current_value}"
class SimpleModuleWithCheckpointing(nn.Module):
def __init__(self):
super().__init__()
self.ffn = nn.Sequential(
nn.Linear(3, 3),
FullyShardedDataParallel(checkpoint_wrapper(nn.Linear(3, 3), maintain_forward_counter=True)),
nn.Linear(3, 3),
nn.Linear(3, 3), FullyShardedDataParallel(checkpoint_wrapper(nn.Linear(3, 3))), nn.Linear(3, 3),
)
def forward(self, x):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment