"docs/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "faf8f1cc06995c56fe06237fd2e485ab7b571546"
Unverified commit a1612d79, authored by Min Xu, committed by GitHub

[fix]: let FSDP handle model with multiple forward pass and checkpoint (#621)



* [fix]: let FSDP handle model with multiple forward pass and checkpoint

* try CI again

* save

* save

* fixed case with bn

* minor

* add the new file

* minor

* added test of a single case, runtime is about 50s

* enable all 8 test cases

* cleanup

* cleanup

* skip flatten case with 1.6 and 1.7

* minor
Co-authored-by: Min Xu <min.xu@acm.org>
parent 5cddaea4
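
As a rough illustration of the pattern this fix enables (a minimal sketch, not part of the commit; it assumes a process group and a GPU are already set up, and the layer sizes are made up), a checkpointed block inside FSDP can now go through more than one forward pass before a single backward:

    import torch
    import torch.nn as nn

    from fairscale.nn import checkpoint_wrapper
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    # Wrap the checkpointed block with maintain_forward_counter=True so FSDP can
    # tell when the last recomputation of the block has finished during backward.
    block = checkpoint_wrapper(nn.Linear(16, 16), maintain_forward_counter=True)
    model = FSDP(nn.Sequential(block, nn.Linear(16, 4)).cuda())

    x1, x2 = torch.randn(2, 16).cuda(), torch.randn(2, 16).cuda()
    loss = model(x1).sum() + model(x2).sum()  # two forward passes in one iteration
    loss.backward()                           # a single backward pass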
@@ -115,9 +115,15 @@ class FullyShardedDataParallel(nn.Module):
         an assert on the backward pass. The solution is to leave some parameters
         to the outer FSDP.
 
+    .. warning::
+
+        If activation checkpointing is used with FSDP, it is strongly encouraged
+        to use ``checkpoint_wrapper`` function from FairScale instead of the
+        ``checkpoint`` function from PyTorch.
+
     Args:
         module (nn.Module):
-            module to checkpoint
+            module to be wrapped with FullyShardedDataParallel.
         process_group (Optional):
             process group for sharding
         reshard_after_forward (bool, Optional):
@@ -207,7 +213,7 @@ class FullyShardedDataParallel(nn.Module):
         self.no_broadcast_optim_state = no_broadcast_optim_state
         self.state_dict_device = state_dict_device or self.compute_device
-        self.gradient_predivide_factor: int = self.get_gradient_predivide_factor(self.world_size)
+        self.gradient_predivide_factor: float = self._get_gradient_predivide_factor(self.world_size)
         self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
         self.numel_padded_per_param: List[int] = []
@@ -275,11 +281,31 @@ class FullyShardedDataParallel(nn.Module):
             f"FSDP.__init__(done): total_init_time: {(init_end - init_start): .4f} num_params: {(sum(p.numel() for p in self.params))}"
         )
 
-    def get_gradient_predivide_factor(self, world_size: int) -> int:
-        factor = 1
+        # Flag to guard multiple pre-forward hook being executed per iteration.
+        # This is reset at the end of the backward pass.
+        self._pre_backward_hook_has_run = False
+
+    def _get_gradient_predivide_factor(self, world_size: int) -> float:
+        factor: int = 1
         while world_size % factor == 0 and world_size / factor > factor:
-            factor = factor * 2
-        return factor
+            factor *= 2
+        return float(factor)
+
+    def set_gradient_divide_factors(self, pre: float, post: float, recursive: bool) -> None:
+        """Allowing user to override the pre and post divide factors.
+
+        Args:
+            pre (float): divide factor before the reduction.
+            post (float): divide factor after the reduction.
+            recursive (bool): recursively set it for all child FSDP instances or not.
+        """
+        self.assert_state(TrainingState.IDLE)
+        if recursive:
+            for module in self.modules():
+                if isinstance(module, FullyShardedDataParallel) and module != self:
+                    module.set_gradient_divide_factors(pre, post, False)
+        self.gradient_predivide_factor = pre
+        self.gradient_postdivide_factor = post
 
     @property
     def module(self) -> nn.Module:
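
For a concrete feel of the default factors computed by ``_get_gradient_predivide_factor`` (worked numbers, not part of the diff): with world_size=8 the loop stops at 4, so gradients are divided by 4 before the reduction and by 8/4 = 2 afterwards; together the two divisions equal dividing by the world size while keeping intermediate values smaller, which helps avoid fp16 overflow. The new override can be used roughly as follows (a sketch; ``fsdp_model`` is a hypothetical FSDP instance, and 1.0/2.0 are the values the new test below uses with world_size=2):

    # Must be called while the instance is IDLE (e.g. before the next forward).
    # pre:  divide factor applied before the gradient reduction.
    # post: divide factor applied after the gradient reduction.
    # recursive=True propagates the setting to all child FSDP instances.
    fsdp_model.set_gradient_divide_factors(pre=1.0, post=2.0, recursive=True)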
@@ -943,7 +969,13 @@ class FullyShardedDataParallel(nn.Module):
             self._use_fp32_param_shard()
 
         # Register pre-backward hooks to all-gather the params for the backward
-        # pass (if needed).
+        # pass (if output's grad was needed). This won't register anything if
+        # we are in eval mode.
+        #
+        # Some model does forward pass multiple times, we need to register the
+        # pre-backward hook on every output since the last output's hook has to
+        # fire first to setup for backward. However, we use ``self._pre_backward_hook_has_run``
+        # to prevent repeated overhead from multiple hook callbacks.
         outputs = self._register_pre_backward_hooks(outputs)
 
         # Done with a forward pass.
@@ -953,16 +985,18 @@ class FullyShardedDataParallel(nn.Module):
     def _register_pre_backward_hooks(self, outputs: Any) -> Any:
         """Register pre-backward hook to run before the wrapped module's
-        backward. Hooks should be attached to all outputs from the forward."""
+        backward. Hooks should be attached to all outputs from the forward.
+
+        Returns:
+            outputs: new outputs with hooks registered if they requires gradient.
+        """
         if not torch.is_grad_enabled():
             return outputs  # don't register hooks if grad isn't enabled
 
-        pre_backward_hook_has_run = [False]
-
         def _pre_backward_hook(*unused: Any) -> None:
-            if pre_backward_hook_has_run[0]:
-                return  # only run once
-            pre_backward_hook_has_run[0] = True
+            if self._pre_backward_hook_has_run:
+                return  # only run once (from multiple outputs or multiple forward passes)
+            self._pre_backward_hook_has_run = True
 
             # Start of a backward pass.
             self.assert_state([TrainingState.IDLE, TrainingState.BACKWARD_PRE])
@@ -1062,13 +1096,27 @@ class FullyShardedDataParallel(nn.Module):
         the local optimizer only sees the relevant parameter shard.
         """
         # First hook callback will see PRE state. If we have multiple params,
-        # then subsequent hook callbacks will see POST state.
-        self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
+        # then subsequent hook callbacks will see POST state. When checkpoint
+        # fwd counter is used, IDLE is also possible since the pre-backward hook
+        # is not triggered (see ``auto_wrap_bn`` below, we have to use
+        # FSDP(checkpoint(conv, FSDP(bn), ...)), with reshard_after_forward=False).
+        if hasattr(self, "_checkpoint_fwd_counter"):
+            self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST, TrainingState.IDLE])
+        else:
+            self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
         self.training_state = TrainingState.BACKWARD_POST
         if param.grad is None:
             return
         if param.grad.requires_grad:
-            raise RuntimeError("FullyShardedDataParallel only works with gradients that don't require grad")
+            raise RuntimeError("FullyShardedDataParallel only works with gradients that don't require gradients")
+
+        # If this is a checkpointed module, we check if the following
+        # counter reaches 0. If not, it is not the final backward call
+        # for this module yet. Therefore, we early return in that case.
+        if hasattr(self._fsdp_wrapped_module, "_checkpoint_fwd_counter"):
+            if self._fsdp_wrapped_module._checkpoint_fwd_counter != 0:
+                return
 
         if self._require_backward_grad_sync or self.reshard_after_forward:
             # Free full params. As a special case, we don't free the full params
@@ -1200,6 +1248,7 @@ class FullyShardedDataParallel(nn.Module):
         for m in self.modules():  # includes self
             if isinstance(m, FullyShardedDataParallel):
                 _remove_shard_bwd_hook(m)
+                m._pre_backward_hook_has_run = False
                 if m._has_params:
                     if any(p.requires_grad for p in m.params):
                         m.assert_state(TrainingState.BACKWARD_POST)
@@ -1395,8 +1444,8 @@ class FullyShardedDataParallel(nn.Module):
             # In case we are failing in the context of autograd hook, asserting
             # may not generate useful msg. So, let's print it to be sure.
             if self.rank == 0:
-                print(self)
-                print(msg)
+                print(f"Asserting FSDP instance is: {self}")
+                print(f"ERROR: {msg}")
                 traceback.print_stack()
             raise ValueError(msg)
@@ -1543,7 +1592,7 @@ class FullyShardedDataParallel(nn.Module):
                 v_shard = v[0] if self.rank >= len(v) else v[self.rank]
                 assert ou.is_singleton_tensor(v_shard)
             else:
-                v_shard = v  # dont shard entries that are not tensors
+                v_shard = v  # don't shard entries that are not tensors
             full_optim_state_dict["state"][id][k] = v_shard
         return full_optim_state_dict
@@ -1686,6 +1735,10 @@ def auto_wrap_bn(module: nn.Module, single_rank_pg: bool = False, process_group:
         "process_group": pg,
         "mixed_precision": False,  # Keep the weights in FP32.
         "flatten_parameters": False,  # Do not flatten.
+        # Reshard==False is good for performance. When FSDP(checkpoint(FSDP(bn))) is used, this
+        # **must** be False because BN's FSDP wrapper's pre-backward callback isn't called
+        # within the checkpoint's outer backward when multiple forward passes are used.
+        "reshard_after_forward": False,
     }
     with enable_wrap(wrap_bn_only_policy, **fsdp_config):
...
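
The wrapping order these comments refer to can be sketched as follows (a hedged illustration based on the new test further down, not code from this diff; the layer sizes are arbitrary and a process group is assumed to be initialized):

    import torch.nn as nn

    from fairscale.nn import checkpoint_wrapper
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
    from fairscale.nn.data_parallel import auto_wrap_bn

    # Inner FSDP wrappers around BN (added by auto_wrap_bn, now with
    # reshard_after_forward=False), then activation checkpointing, then the
    # outer FSDP, i.e. FSDP(checkpoint(conv, FSDP(bn), ...)).
    block = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
    block = auto_wrap_bn(block, single_rank_pg=False)
    block = checkpoint_wrapper(block, maintain_forward_counter=True)
    block = FSDP(block)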
@@ -15,10 +15,12 @@ import torch.utils.checkpoint as torch_checkpoint
 from fairscale.utils.containers import pack_kwargs, split_non_tensors, unpack_kwargs, unpack_non_tensors
 
-from .misc import patch_batchnorm
+from .misc import dec_counter, inc_counter, init_counter, patch_batchnorm
 
 
-def checkpoint_wrapper(module: nn.Module, offload_to_cpu: bool = False) -> nn.Module:
+def checkpoint_wrapper(
+    module: nn.Module, offload_to_cpu: bool = False, maintain_forward_counter: bool = False
+) -> nn.Module:
     """
     A friendlier wrapper for performing activation checkpointing.
@@ -58,16 +60,23 @@ def checkpoint_wrapper(module: nn.Module, offload_to_cpu: bool = False) -> nn.Mo
     Args:
         module (nn.Module):
             The module to be wrapped
-        offload_to_cpu (Optional, bool):
+        offload_to_cpu (bool):
             Whether to offload activations to CPU.
+        maintain_forward_counter (bool):
+            If True, maintain a forward counter per inner module. The counter will first
+            increases in forward calls of outer forward pass and then decreases in the
+            forward calls of outer backward pass. It is used by FullyShardedDataParallel.
 
     Returns:
         (nn.Module):
             Wrapped module
     """
-    # Patch the batchnorm layers in case there are any.
+    # Patch the batchnorm layers in case there are any in this module.
     patch_batchnorm(module)
 
+    if maintain_forward_counter:
+        init_counter(module)
+
     # The use of weakref here is to prevent creating a ref cycle: m -> m.forward -> m.
     # When such cycle exists, gc won't collect the module when the module is freed.
     # That causes GPU memory to be leaked. See the unit test for how we catch that.
@@ -168,6 +177,8 @@ class CheckpointFunction(torch.autograd.Function):
         with torch.no_grad():
             unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)
             outputs = run_function(*unpacked_args, **unpacked_kwargs)
+            the_module = unpacked_args[0]
+            inc_counter(the_module)
 
         if not isinstance(outputs, torch.Tensor):
             # Autograd Functions don't like non-Tensor outputs. We can split the
@@ -200,6 +211,8 @@ class CheckpointFunction(torch.autograd.Function):
             unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs)
             outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs)
             tensor_outputs, _ = split_non_tensors(outputs)
+            the_module = unpacked_args[0]
+            dec_counter(the_module)
 
         # Set the states back to what it was at the start of this function.
         set_rng_state(bwd_rng_state)
...
@@ -27,7 +27,6 @@ def patch_batchnorm(module: nn.Module) -> List:
         (list):
             A list of hook handles, late can be freed.
     """
-    hooks = []
 
     def pre_forward(module: _BatchNorm, input: Tensor) -> None:
         if torch.is_grad_enabled():
@@ -40,6 +39,7 @@ def patch_batchnorm(module: nn.Module) -> List:
             return
         module.track_running_stats = module._track_running_stats_backup
 
+    hooks = []
     for name, child in module.named_modules():
         # _BatchNorm is base for bn1d, bn2d, bn3d and sync_bn, apex_sync_bn, etc.
         if isinstance(child, _BatchNorm):
@@ -48,3 +48,28 @@ def patch_batchnorm(module: nn.Module) -> List:
             post_handle = child.register_forward_hook(post_forward)
             hooks += [pre_handle, post_handle]
     return hooks
+
+
+def init_counter(module: nn.Module) -> None:
+    """Add a checkpoint forward pass counter to a module and all its child FSDP modules.
+
+    ``inc_counter`` and ``dec_counter`` are used together with this to maintain counters
+    for FSDP to use in case of multiple forward pass and checkpoint being used at the same time.
+    """
+    for mod in module.modules():
+        mod._checkpoint_fwd_counter = 0
+
+
+def _add_counter(module: nn.Module, value: int) -> None:
+    if not hasattr(module, "_checkpoint_fwd_counter"):
+        return
+    for mod in module.modules():
+        mod._checkpoint_fwd_counter += value
+
+
+def inc_counter(module: nn.Module) -> None:
+    _add_counter(module, 1)
+
+
+def dec_counter(module: nn.Module) -> None:
+    _add_counter(module, -1)
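
To make the counter lifecycle concrete (a toy walkthrough, not part of the commit; it calls the helpers above directly instead of going through CheckpointFunction, and the Linear module is arbitrary):

    import torch.nn as nn

    block = nn.Linear(8, 8)
    init_counter(block)  # every submodule now has _checkpoint_fwd_counter == 0

    # CheckpointFunction.forward calls inc_counter once per checkpointed forward,
    # so two forward passes in one iteration leave the counter at 2.
    inc_counter(block)
    inc_counter(block)
    assert block._checkpoint_fwd_counter == 2

    # During backward, each recomputation calls dec_counter. FSDP's post-backward
    # hook only proceeds (frees params, reduces grads) once the counter is back to 0,
    # i.e. after the final recomputation of this module.
    dec_counter(block)
    dec_counter(block)
    assert block._checkpoint_fwd_counter == 0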
@@ -26,6 +26,7 @@ you see fit, but refrain from ad-hoc test utils within the different feature set
 relative imports.
 """
 
+import contextlib
 import functools
 import inspect
 import logging
@@ -35,7 +36,7 @@ import random
 import subprocess
 import sys
 import tempfile
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Tuple, Union
 
 import numpy
 import pytest
@@ -645,3 +646,15 @@ def rmf(filename: str) -> None:
         os.remove(filename)
     except FileNotFoundError:
         pass
+
+
+@contextlib.contextmanager
+def temp_files_ctx(num: int) -> Generator:
+    """ A context to get tempfiles and ensure they are cleaned up. """
+    files = [tempfile.mkstemp()[1] for _ in range(num)]
+
+    yield tuple(files)
+
+    # temp files could have been removed, so we use rmf.
+    for name in files:
+        rmf(name)
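
The new test below uses this helper roughly as follows (a small sketch mirroring the test; the variable names are illustrative):

    # Two files for dist_init plus one per rank for saving losses; all of them
    # are removed when the with-block exits, even if a worker already deleted one.
    world_size = 2
    with temp_files_ctx(num=2 + world_size) as temp_files:
        init_file, rpc_file = temp_files[:2]
        loss_files = temp_files[2:]
        # ... spawn workers that read/write these paths ...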
@@ -108,6 +108,8 @@ class Module(Generic[T_co]):
     def extra_repr(self) -> str: ...
 
-    #MODIFIED BY TORCHGPIPE
+    # This is added by checkpoint_wrapper
+    _checkpoint_fwd_counter: int
+
+    # This is added torchgpipe
     training: bool
     #END
-tests/nn/misc/test_flatten_params_wrapper.py
+tests/nn/data_parallel/test_fsdp_multiple_forward_checkpoint.py
+tests/nn/data_parallel/test_fsdp.py
+tests/nn/data_parallel/test_fsdp_freezing_weights.py
 tests/nn/data_parallel/test_fsdp_multiple_wrapping.py
-tests/nn/data_parallel/test_fsdp_freezing_weights.py
-tests/nn/data_parallel/test_fsdp.py
+tests/nn/misc/test_flatten_params_wrapper.py
@@ -118,7 +118,7 @@ def temp_files():
 @skip_if_single_gpu
-def tests1(temp_files):
+def test_freezing_weights(temp_files):
     world_size = 2
 
     # DDP
     fsdp = False
...
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring
""" Test FSDP with multiple forward pass + checkpoint. """
import contextlib
import pickle
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
import torch.optim as optim
from fairscale.nn import checkpoint_wrapper
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import auto_wrap_bn
from fairscale.nn.wrap import enable_wrap, wrap
from fairscale.utils.testing import (
    dist_init,
    objects_are_equal,
    skip_if_single_gpu,
    teardown,
    temp_files_ctx,
    torch_version,
)
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3), nn.BatchNorm2d(64), nn.ReLU(inplace=True),)
        self.block2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
        )
        self.head = nn.Linear(128, 10)

    def forward(self, x):
        if isinstance(x, torch.Tensor):
            return self.head(self.block2(self.block1(x)))
        elif isinstance(x, list):
            ys = [self.head(self.block2(self.block1(e))) for e in x]
            return torch.cat(ys, dim=0)
def create_model(with_fsdp, with_checkpoint, mixed_precision, flatten, wrap_bn, fp32_reduce_scatter):
    model = Model()
    if with_fsdp:
        if wrap_bn:
            model.block1 = auto_wrap_bn(model.block1, single_rank_pg=False)
            model.block2 = auto_wrap_bn(model.block2, single_rank_pg=False)
        if with_checkpoint:
            model.block2 = checkpoint_wrapper(model.block2, maintain_forward_counter=True)
        with enable_wrap(
            wrapper_cls=FSDP,
            flatten_parameters=flatten,
            mixed_precision=mixed_precision,
            compute_dtype=torch.float32,
            fp32_reduce_scatter=fp32_reduce_scatter,
        ):
            model.block1 = wrap(model.block1)
            model.block2 = wrap(model.block2)
            model.head = wrap(model.head)
    else:
        if with_checkpoint:
            model.block2 = checkpoint_wrapper(model.block2, maintain_forward_counter=False)
    return model
def _distributed_worker(
    gpu_id, world_size, with_fsdp, with_checkpoint, files, mixed_precision, flatten, wrap_bn, fp32_reduce_scatter
):
    filename, filename_rpc = files[:2]
    filename_loss = files[2:]

    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, filename, filename_rpc)
    assert result, "Dist init failed"

    # use False below to debug since error msg is not as good with cudnn.
    torch.backends.cudnn.enabled = True

    # these make things deterministic.
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Ensure we have multiple forward passes.
    batch = [
        torch.randn(size=(2, 3, 224, 224)).cuda(),
        torch.randn(size=(2, 3, 96, 96)).cuda(),
        torch.randn(size=(2, 3, 96, 96)).cuda(),
    ]

    if mixed_precision and not with_fsdp:
        batch = [x.half() for x in batch]

    model = create_model(with_fsdp, with_checkpoint, mixed_precision, flatten, wrap_bn, fp32_reduce_scatter)
    model = model.cuda()

    if with_fsdp:
        model = FSDP(
            model,
            flatten_parameters=flatten,
            mixed_precision=mixed_precision,
            compute_dtype=torch.float32,
            fp32_reduce_scatter=fp32_reduce_scatter,
        )
        model.set_gradient_divide_factors(1.0, 2.0, True)
        no_sync_context = contextlib.suppress()
    else:
        # With DDP, we need no_sync and manual gradient reduction below because
        # it can't handle multiple forward pass + checkpointing otherwise.
        model = DistributedDataParallel(model, device_ids=[gpu_id])
        no_sync_context = model.no_sync()

    mp_context = contextlib.suppress()
    if mixed_precision:
        mp_context = torch.cuda.amp.autocast(enabled=True)

    if gpu_id == 0:
        print(model)

    target = torch.LongTensor([0, 1, 2, 3, 4, 5]).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    losses = {}
    i = 0
    with no_sync_context:
        for iteration in range(3):
            with mp_context:
                out = model(batch)
                loss = criterion(out, target)
                print("Loss", iteration, ":", loss.item())
                losses[f"iter_{i}"] = loss
                i += 1

            optimizer.zero_grad()
            loss.backward()

            # Manual grad reduction, no autocast.
            if not with_fsdp:
                for p in model.parameters():
                    dist.all_reduce(p.grad.data)
                    p.grad.data.div_(2.0)

            # Stepping, no autocast
            optimizer.step()

    # Due to dist.all_reduce code block above with ddp.no_sync, we seem to hit a bug
    # in DDP where tensor.cpu() and torch.save() calls both hang. FSDP is not affected.
    # Therefore, we have to compare losses here instead of states.
    with open(filename_loss[rank], "wb") as f:
        pickle.dump(losses, f)

    teardown()
@skip_if_single_gpu
@pytest.mark.parametrize("precision", ["full", "mixed"])
@pytest.mark.parametrize("flatten", ["flatten", "no_flatten"])
@pytest.mark.parametrize("wrap_bn", ["auto_wrap_bn", "no_auto_wrap_bn"])
def test_multiple_forward_checkpoint(precision, flatten, wrap_bn):
    mixed_precision = precision == "mixed"
    flatten = flatten == "flatten"
    wrap_bn = wrap_bn == "auto_wrap_bn"
    fp32_reduce_scatter = True if mixed_precision else None

    if torch_version() < (1, 8, 0) and flatten:
        # 1.6 and 1.7 throws this error:
        #   RuntimeError: Trying to backward through the graph a second time, but the saved
        #   intermediate results have already been freed. Specify retain_graph=True when calling
        #   backward the first time.
        pytest.skip("older pytorch throws error when flatten is used")

    world_size = 2
    expected_losses = None
    # Ensure ddp == ddp+ckpt == fsdp == fsdp+ckpt.
    for with_fsdp in [False, True]:
        for with_checkpoint in [False, True]:
            # Get 4 files: 2 for dist_init and 2 for each rank to save the losses.
            with temp_files_ctx(num=2 + world_size) as temp_files:
                mp.spawn(
                    _distributed_worker,
                    (
                        world_size,
                        with_fsdp,
                        with_checkpoint,
                        temp_files,
                        mixed_precision,
                        flatten,
                        wrap_bn,
                        fp32_reduce_scatter,
                    ),
                    nprocs=world_size,
                )
                final_losses = {}
                for rank in range(world_size):
                    with open(temp_files[2 + rank], "rb") as f:
                        final_losses[f"rank_{rank}"] = pickle.load(f)
                if expected_losses is None:
                    expected_losses = final_losses
                else:
                    print(f"fsdp: {with_fsdp} ckpt: {with_checkpoint}")
                    assert objects_are_equal(expected_losses, final_losses, raise_exception=True)