Unverified Commit 0233efca authored by Min Xu, committed by GitHub

[test] FSDP: check DDP parity with conv + bn (#549)

- added DDP equivalency test
- added rmf, state_dict_norm functions to testing utils
- added more debugging output to objects_are_equal
parent a2b11de4
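The core idea of the new test is to train the same conv + BN model under DDP (as a reference) and under FSDP, then compare the resulting state_dicts. Below is a heavily simplified, single-process sketch of that parity pattern; the `train` helper and the two identically seeded runs are purely illustrative and are not code from this commit (the real test spawns two GPU processes and puts an FSDP-wrapped model on one side of the comparison).

import torch
from torch.nn import BatchNorm2d, Conv2d, Sequential
from fairscale.utils.testing import objects_are_equal

def train(seed):
    # Train a tiny conv + BN model deterministically and return its weights.
    torch.manual_seed(seed)
    model = Sequential(Conv2d(2, 2, 1), BatchNorm2d(2))
    optim = torch.optim.SGD(model.parameters(), lr=0.1)
    for _ in range(10):
        model(torch.rand(2, 2, 2, 2)).sum().backward()
        optim.step()
        optim.zero_grad()
    return model.state_dict()

# Two identically seeded runs must produce identical state_dicts; in the real
# test one side is a DDP reference and the other an FSDP-wrapped model.
assert objects_are_equal(train(0), train(0), raise_exception=True)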
@@ -400,33 +400,53 @@ class GPT2(Base):
return self.clf_head(h), logits
def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool:
def objects_are_equal(a: Any, b: Any, raise_exception: bool = False, dict_key: Optional[str] = None) -> bool:
"""
Test that two objects are equal. Tensors are compared to ensure matching
size, dtype, device and values.
"""
if type(a) is not type(b):
if raise_exception:
raise ValueError(f"type mismatch {type(a)} vs. {type(b)}")
return False
if isinstance(a, dict):
if set(a.keys()) != set(b.keys()):
if raise_exception:
raise ValueError(f"keys mismatch {a.keys()} vs. {b.keys()}")
return False
for k in a.keys():
if not objects_are_equal(a[k], b[k], raise_exception):
if not objects_are_equal(a[k], b[k], raise_exception, k):
return False
return True
elif isinstance(a, (list, tuple, set)):
if len(a) != len(b):
if raise_exception:
raise ValueError(f"length mismatch {len(a)} vs. {len(b)}")
return False
return all(objects_are_equal(x, y, raise_exception) for x, y in zip(a, b))
elif torch.is_tensor(a):
try:
torch.testing.assert_allclose(a, b)
# assert_allclose doesn't strictly test shape, dtype and device
shape_dtype_device_match = a.size() == b.size() and a.dtype == b.dtype and a.device == b.device
assert shape_dtype_device_match
if not shape_dtype_device_match:
if raise_exception:
msg = f"sizes: {a.size()} vs. {b.size()}, "
msg += f"types: {a.dtype} vs. {b.dtype}, "
msg += f"device: {a.device} vs. {b.device}"
raise AssertionError(msg)
else:
return False
# Now compare values with assert_allclose.
torch.testing.assert_allclose(a, b)
return True
except (AssertionError, RuntimeError) as e:
if raise_exception:
if dict_key and isinstance(e, AssertionError):
# Add dict key to the assertion error.
msg = e.args[0]
new_msg = f"For dict key '{dict_key}': {msg}"
raise AssertionError(new_msg)
else:
raise e
else:
return False
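To illustrate the new debugging output: when a tensor nested in a dict mismatches, the assertion error is now prefixed with the offending key. A small sketch follows; the key name and tensors are made up for illustration.

import torch
from fairscale.utils.testing import objects_are_equal

a = {"bn.running_mean": torch.zeros(3)}
b = {"bn.running_mean": torch.ones(3)}
try:
    objects_are_equal(a, b, raise_exception=True)
except AssertionError as e:
    # Message is expected to start with "For dict key 'bn.running_mean': ...".
    print(e)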
@@ -582,3 +602,21 @@ class SGDWithPausingCompute(torch.optim.SGD):
param *= 1.0 + self.rank / 10.0
return loss
def state_dict_norm(state: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Compute the norm from a state_dict for simple comparison."""
norm = torch.zeros(1)
for v in state.values():
if not v.is_floating_point():
v = v.float()
norm += v.norm()
return norm
def rmf(filename: str) -> None:
"""Remove a file like rm -f."""
try:
os.remove(filename)
except FileNotFoundError:
pass
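A quick, hedged sketch of how these two helpers are typically used together in the tests; the Linear model and temp file here are illustrative only.

import tempfile
import torch
from fairscale.utils.testing import rmf, state_dict_norm

# state_dict_norm gives a cheap scalar fingerprint of a model's weights.
model = torch.nn.Linear(4, 4)
print(state_dict_norm(model.state_dict()))

# rmf removes a file and, like `rm -f`, ignores files that are already gone.
path = tempfile.mkstemp()[1]
torch.save(model.state_dict(), path)
rmf(path)
rmf(path)  # no error on the second call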
@@ -922,7 +922,7 @@ class Tensor:
# TODO: fill in the types for these, or otherwise figure out some
# way to not have to write these out again...
def nonzero(self, *, as_tuple=True): ...
def norm(self, p="fro", dim=None, keepdim=False): ...
def norm(self, p="fro", dim=None, keepdim=False, out=None, dtype=None) -> Tensor: ...
def stft(self, n_fft, hop_length=None, win_length=None, window=None,
center=True, pad_mode='reflect', normalized=False, onesided=True): ...
def split(self, split_size, dim=0) -> Tuple[Tensor, ...]: ...
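The widened `norm` stub above also gains a `Tensor` return annotation; presumably this is what lets accumulation patterns like the one in `state_dict_norm` type-check under mypy. A tiny sketch of that pattern (not from the commit):

import torch

# With Tensor.norm() annotated to return Tensor, this accumulation type-checks.
total: torch.Tensor = torch.zeros(1)
total += torch.ones(3).norm()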
@@ -9,6 +9,7 @@
""" Test FSDP with regnet-like model. """
import contextlib
import random
import tempfile
@@ -16,20 +17,23 @@ import pytest
import torch
import torch.multiprocessing as mp
from torch.nn import BatchNorm2d, Conv2d, Module, SyncBatchNorm
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import TrainingState, auto_wrap_bn
from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, torch_version
def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
result = dist_init(rank, world_size, tempfile_name, unused)
assert result, "Dist init failed"
assert isinstance(fsdp_config, dict), str(fsdp_config)
class Model(Module):
from fairscale.utils.testing import (
dist_init,
objects_are_equal,
rmf,
skip_if_single_gpu,
state_dict_norm,
teardown,
torch_version,
)
class Model(Module):
def __init__(self):
super().__init__()
# TODO (Min): for now, we just test pytorch sync_bn here.
@@ -42,9 +46,112 @@ def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
x = self.bn(x)
return x
# TODO (Min): check DDP equivalency.
# We get a bit fancy here. Since the scope is `module`, this is run only
# once no matter how many test variations for FSDP are requested to run
# to compare with the DDP reference. For example, a single DDP
# reference run is needed for both flatten and non-flatten param FSDP.
#
# Note, this runs DDP twice with and without mixed precision and asserts
# the resulting weights are different.
#
# This fixture captures and returns:
#
# - model state_dict before training
# - model data inputs
# - model state_dict after training
@pytest.fixture(scope="module")
def ddp_ref():
# Get a reference model state
model = Model()
state_before = model.state_dict()
# Get reference inputs per rank.
world_size = 2
iterations = 100
inputs = [[] for _ in range(world_size)]  # independent input list per rank
for rank in range(world_size):
for i in range(iterations):
inputs[rank].append(torch.rand(2, 2, 2, 2))
# Run DDP training twice: once in full precision (fp) and once in mixed precision (mp).
for precision in ["full", "mixed"]:
temp_file_name = tempfile.mkstemp()[1]
unused = tempfile.mkstemp()[1]
rank_0_output = tempfile.mkstemp()[1]
try:
fsdp_config = None # This means we use DDP in _test_func.
mp.spawn(
_test_func,
args=(
world_size,
fsdp_config,
precision == "mixed",
temp_file_name,
unused,
state_before,
inputs,
rank_0_output,
None,
),
nprocs=world_size,
join=True,
)
if precision == "full":
state_after_fp = torch.load(rank_0_output)
else:
state_after_mp = torch.load(rank_0_output)
finally:
rmf(temp_file_name)
rmf(unused)
rmf(rank_0_output)
assert state_dict_norm(state_after_fp) != state_dict_norm(state_after_mp)
return state_before, inputs, state_after_fp, state_after_mp
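The `module` scope is what lets this expensive DDP reference run once and be reused by every FSDP parametrization in the file. Below is a minimal, generic illustration of that pytest behaviour, with hypothetical fixture and test names unrelated to this repository.

import pytest

calls = []

@pytest.fixture(scope="module")
def expensive_reference():
    calls.append(1)  # executed once for the whole module
    return sum(range(10))

@pytest.mark.parametrize("variant", ["flatten", "no_flatten"])
def test_variant(expensive_reference, variant):
    assert expensive_reference == 45
    assert len(calls) == 1  # still one fixture call across all parametrizations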
# A fixture to get tempfiles and ensure they are cleaned up.
@pytest.fixture()
def temp_files():
temp_file_name = tempfile.mkstemp()[1]
unused = tempfile.mkstemp()[1]
yield temp_file_name, unused
# temp files could have been removed, so we use rmf.
rmf(temp_file_name)
rmf(unused)
def _test_func(
rank,
world_size,
fsdp_config,
ddp_mixed_precision,
tempfile_name,
unused,
state_before,
inputs,
rank_0_output,
state_after,
):
result = dist_init(rank, world_size, tempfile_name, unused)
assert result, "Dist init failed"
ddp = True
if fsdp_config:
ddp = False
assert isinstance(fsdp_config, dict), str(fsdp_config)
model = Model()
model.load_state_dict(state_before)
model = model.cuda()
if ddp:
model = SyncBatchNorm.convert_sync_batchnorm(model)
model = DDP(model, device_ids=[rank])
else:
# Note: different ranks may wrap in a different order due to different random
# seeds, but the results should be the same.
if random.randint(0, 1) == 0:
@@ -58,37 +165,60 @@ def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
model = FSDP(model, **fsdp_config).cuda()
optim = SGD(model.parameters(), lr=0.1)
for _ in range(3):
in_data = torch.rand(2, 2, 2, 2).cuda()
in_data.requires_grad = True
for in_data in inputs[rank]:
in_data = in_data.cuda()
context = contextlib.suppress()
if ddp and ddp_mixed_precision:
in_data = in_data.half()
context = torch.cuda.amp.autocast(enabled=True)
with context:
out = model(in_data)
out.sum().backward()
loss = out.sum()
loss.backward()
optim.step()
optim.zero_grad()
if ddp:
# Save the rank 0 state_dict to the output file.
if rank == 0:
state_after = model.module.cpu().state_dict()
torch.save(state_after, rank_0_output)
else:
model.assert_state(TrainingState.IDLE)
# Ensure the final state equals state_after.
fsdp_state = model.state_dict()
# Move tensors to CPU to compare numerics.
for k, v in fsdp_state.items():
fsdp_state[k] = v.cpu()
assert objects_are_equal(state_after, fsdp_state, raise_exception=True)
teardown()
# We use strings for precision and flatten instead of bool to
# We use strings for precision and flatten params instead of bool to
# make the pytest output more readable.
@skip_if_single_gpu
@pytest.mark.parametrize("precision", ["full", "mixed"])
@pytest.mark.parametrize("flatten", ["flatten", "no_flatten"])
def test1(precision, flatten):
def test1(temp_files, ddp_ref, precision, flatten):
if torch_version() < (1, 6, 0):
pytest.skip("older pytorch doesn't support reduce_scatter")
temp_file_name = tempfile.mkstemp()[1]
unused = tempfile.mkstemp()[1]
state_before, inputs, state_after_fp, state_after_mp = ddp_ref
if precision == "full":
state_after = state_after_fp
else:
state_after = state_after_mp
fsdp_config = {}
fsdp_config["mixed_precision"] = precision == "mixed"
fsdp_config["flatten_parameters"] = flatten == "flatten"
# Some bugs only show up when we are in world_size > 1 due to sharding changing
# the tensor dimensions.
world_size = 2
mp.spawn(
_test_func, args=(world_size, fsdp_config, temp_file_name, unused), nprocs=world_size, join=True,
_test_func,
args=(world_size, fsdp_config, None, temp_files[0], temp_files[1], state_before, inputs, None, state_after),
nprocs=world_size,
join=True,
)