Unverified commit 0233efca authored by Min Xu, committed by GitHub

[test] FSDP: check DDP parity with conv + bn (#549)

- added DDP equivalency test
- added rmf, state_dict_norm functions to testing utils
- added more debugging output to objects_are_equal
parent a2b11de4
@@ -400,33 +400,53 @@ class GPT2(Base):
         return self.clf_head(h), logits
 
 
-def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool:
+def objects_are_equal(a: Any, b: Any, raise_exception: bool = False, dict_key: Optional[str] = None) -> bool:
     """
     Test that two objects are equal. Tensors are compared to ensure matching
     size, dtype, device and values.
     """
     if type(a) is not type(b):
+        if raise_exception:
+            raise ValueError(f"type mismatch {type(a)} vs. {type(b)}")
         return False
     if isinstance(a, dict):
         if set(a.keys()) != set(b.keys()):
+            if raise_exception:
+                raise ValueError(f"keys mismatch {a.keys()} vs. {b.keys()}")
             return False
         for k in a.keys():
-            if not objects_are_equal(a[k], b[k], raise_exception):
+            if not objects_are_equal(a[k], b[k], raise_exception, k):
                 return False
         return True
     elif isinstance(a, (list, tuple, set)):
         if len(a) != len(b):
+            if raise_exception:
+                raise ValueError(f"length mismatch {len(a)} vs. {len(b)}")
             return False
         return all(objects_are_equal(x, y, raise_exception) for x, y in zip(a, b))
     elif torch.is_tensor(a):
         try:
-            torch.testing.assert_allclose(a, b)
             # assert_allclose doesn't strictly test shape, dtype and device
             shape_dtype_device_match = a.size() == b.size() and a.dtype == b.dtype and a.device == b.device
-            assert shape_dtype_device_match
+            if not shape_dtype_device_match:
+                if raise_exception:
+                    msg = f"sizes: {a.size()} vs. {b.size()}, "
+                    msg += f"types: {a.dtype} vs. {b.dtype}, "
+                    msg += f"device: {a.device} vs. {b.device}"
+                    raise AssertionError(msg)
+                else:
+                    return False
+            # assert_allclose.
+            torch.testing.assert_allclose(a, b)
             return True
         except (AssertionError, RuntimeError) as e:
             if raise_exception:
-                raise e
+                if dict_key and isinstance(e, AssertionError):
+                    # Add dict key to the assertion error.
+                    msg = e.args[0]
+                    new_msg = f"For dict key '{dict_key}': {msg}"
+                    raise AssertionError(new_msg)
+                else:
+                    raise e
             else:
                 return False
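
For context, a minimal usage sketch of the updated helper; the two dicts below are made up and differ only in dtype, which is exactly the kind of mismatch the new checks report:

    import torch

    from fairscale.utils.testing import objects_are_equal

    # Hypothetical state_dicts whose values differ only in dtype.
    a = {"weight": torch.ones(2, 2)}
    b = {"weight": torch.ones(2, 2, dtype=torch.float64)}

    # Quiet mode: just returns False.
    assert not objects_are_equal(a, b)

    # Verbose mode: the AssertionError now carries the offending dict key,
    # e.g. "For dict key 'weight': sizes: ... types: ... device: ...".
    try:
        objects_are_equal(a, b, raise_exception=True)
    except AssertionError as exc:
        print(exc)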
@@ -582,3 +602,21 @@ class SGDWithPausingCompute(torch.optim.SGD):
                 param *= 1.0 + self.rank / 10.0
 
         return loss
+
+
+def state_dict_norm(state: Dict[str, torch.Tensor]) -> torch.Tensor:
+    """Compute the norm from a state_dict for simple comparison."""
+    norm = torch.zeros(1)
+    for v in state.values():
+        if not v.is_floating_point():
+            v = v.float()
+        norm += v.norm()
+    return norm
+
+
+def rmf(filename: str) -> None:
+    """Remove a file like rm -f."""
+    try:
+        os.remove(filename)
+    except FileNotFoundError:
+        pass
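
And a rough sketch of how these two new utilities can be used together; the Linear model and checkpoint file here are placeholders, not part of the commit:

    import tempfile

    import torch
    from torch.nn import Linear

    from fairscale.utils.testing import rmf, state_dict_norm

    model = Linear(4, 4)  # placeholder model
    # One scalar summary of all parameters, handy for cheap equality checks.
    norm_before = state_dict_norm(model.state_dict())

    checkpoint = tempfile.mkstemp()[1]
    torch.save(model.state_dict(), checkpoint)
    assert torch.equal(norm_before, state_dict_norm(torch.load(checkpoint)))

    # Clean up; like `rm -f`, a second call on a missing file is a no-op.
    rmf(checkpoint)
    rmf(checkpoint)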
@@ -922,7 +922,7 @@ class Tensor:
     # TODO: fill in the types for these, or otherwise figure out some
     # way to not have to write these out again...
     def nonzero(self, *, as_tuple=True): ...
-    def norm(self, p="fro", dim=None, keepdim=False): ...
+    def norm(self, p="fro", dim=None, keepdim=False, out=None, dtype=None) -> Tensor: ...
     def stft(self, n_fft, hop_length=None, win_length=None, window=None,
              center=True, pad_mode='reflect', normalized=False, onesided=True): ...
     def split(self, split_size, dim=0) -> Tuple[Tensor, ...]: ...
@@ -9,6 +9,7 @@
 """ Test FSDP with regnet-like model. """
 
+import contextlib
 import random
 import tempfile
@@ -16,20 +17,23 @@ import pytest
 import torch
 import torch.multiprocessing as mp
 from torch.nn import BatchNorm2d, Conv2d, Module, SyncBatchNorm
+from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import SGD
 
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import TrainingState, auto_wrap_bn
-from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, torch_version
-
-
-def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
-    result = dist_init(rank, world_size, tempfile_name, unused)
-    assert result, "Dist init failed"
-
-    assert isinstance(fsdp_config, dict), str(fsdp_config)
-
-    class Model(Module):
+from fairscale.utils.testing import (
+    dist_init,
+    objects_are_equal,
+    rmf,
+    skip_if_single_gpu,
+    state_dict_norm,
+    teardown,
+    torch_version,
+)
+
+
+class Model(Module):
     def __init__(self):
         super().__init__()
         # TODO (Min): for now, we just test pytorch sync_bn here.
@@ -42,9 +46,112 @@ def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
         x = self.bn(x)
         return x
 
-    # TODO (Min): check DDP equivalency.
-
+
+# We get a bit fancy here. Since the scope is `module`, this is run only
+# once no matter how many test variations for FSDP are requested to run
+# to compare with the DDP reference. For example, a single DDP
+# reference run is needed for both flatten and non-flatten param FSDP.
+#
+# Note, this runs DDP twice with and without mixed precision and asserts
+# the resulting weights are different.
+#
+# This fixture captures and returns:
+#
+#   - model state_dict before training
+#   - model data inputs
+#   - model state_dict after training
+@pytest.fixture(scope="module")
+def ddp_ref():
+    # Get a reference model state
     model = Model()
+    state_before = model.state_dict()
+
+    # Get reference inputs per rank.
+    world_size = 2
+    iterations = 100
+    inputs = [[]] * world_size
+    for rank in range(world_size):
+        for i in range(iterations):
+            inputs[rank].append(torch.rand(2, 2, 2, 2))
+
+    # Run DDP training twice, fp and mp.
+    for precision in ["full", "mixed"]:
+        temp_file_name = tempfile.mkstemp()[1]
+        unused = tempfile.mkstemp()[1]
+        rank_0_output = tempfile.mkstemp()[1]
+        try:
+            fsdp_config = None  # This means we use DDP in _test_func.
+            mp.spawn(
+                _test_func,
+                args=(
+                    world_size,
+                    fsdp_config,
+                    precision == "mixed",
+                    temp_file_name,
+                    unused,
+                    state_before,
+                    inputs,
+                    rank_0_output,
+                    None,
+                ),
+                nprocs=world_size,
+                join=True,
+            )
+            if precision == "full":
+                state_after_fp = torch.load(rank_0_output)
+            else:
+                state_after_mp = torch.load(rank_0_output)
+        finally:
+            rmf(temp_file_name)
+            rmf(unused)
+            rmf(rank_0_output)
+
+    assert state_dict_norm(state_after_fp) != state_dict_norm(state_after_mp)
+
+    return state_before, inputs, state_after_fp, state_after_mp
+
+
+# A fixture to get tempfiles and ensure they are cleaned up.
+@pytest.fixture()
+def temp_files():
+    temp_file_name = tempfile.mkstemp()[1]
+    unused = tempfile.mkstemp()[1]
+
+    yield temp_file_name, unused
+
+    # temp files could have been removed, so we use rmf.
+    rmf(temp_file_name)
+    rmf(unused)
+
+
+def _test_func(
+    rank,
+    world_size,
+    fsdp_config,
+    ddp_mixed_precision,
+    tempfile_name,
+    unused,
+    state_before,
+    inputs,
+    rank_0_output,
+    state_after,
+):
+    result = dist_init(rank, world_size, tempfile_name, unused)
+    assert result, "Dist init failed"
+
+    ddp = True
+    if fsdp_config:
+        ddp = False
+        assert isinstance(fsdp_config, dict), str(fsdp_config)
+
+    model = Model()
+    model.load_state_dict(state_before)
+    model = model.cuda()
+    if ddp:
+        model = SyncBatchNorm.convert_sync_batchnorm(model)
+        model = DDP(model, device_ids=[rank])
+    else:
         # Note, different rank may wrap in different order due to different random
         # seeds. But results should be the same.
         if random.randint(0, 1) == 0:
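
The `scope="module"` trick used by the ddp_ref fixture above is standard pytest behavior: a module-scoped fixture body runs once and its return value is shared by every test (and every parametrization) in the file. A tiny self-contained illustration of that mechanism, unrelated to FSDP and assuming only that pytest is installed:

    import pytest

    CALLS = []


    @pytest.fixture(scope="module")
    def expensive_reference():
        CALLS.append(1)  # runs once per module, not once per test
        return sum(range(100))


    @pytest.mark.parametrize("x", [1, 2, 3])
    def test_uses_reference(expensive_reference, x):
        # All three parametrized tests see the same, single computation.
        assert expensive_reference == 4950
        assert len(CALLS) == 1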
@@ -58,37 +165,60 @@ def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
         model = FSDP(model, **fsdp_config).cuda()
     optim = SGD(model.parameters(), lr=0.1)
 
-    for _ in range(3):
-        in_data = torch.rand(2, 2, 2, 2).cuda()
-        in_data.requires_grad = True
-        out = model(in_data)
-        out.sum().backward()
+    for in_data in inputs[rank]:
+        in_data = in_data.cuda()
+        context = contextlib.suppress()
+        if ddp and ddp_mixed_precision:
+            in_data = in_data.half()
+            context = torch.cuda.amp.autocast(enabled=True)
+        with context:
+            out = model(in_data)
+        loss = out.sum()
+        loss.backward()
         optim.step()
         optim.zero_grad()
 
-    model.assert_state(TrainingState.IDLE)
+    if ddp:
+        # Save the rank 0 state_dict to the output file.
+        if rank == 0:
+            state_after = model.module.cpu().state_dict()
+            torch.save(state_after, rank_0_output)
+    else:
+        model.assert_state(TrainingState.IDLE)
+
+        # Ensure final state equals to the state_after.
+        fsdp_state = model.state_dict()
+        # Move tensors to CPU to compare numerics.
+        for k, v in fsdp_state.items():
+            fsdp_state[k] = v.cpu()
+        assert objects_are_equal(state_after, fsdp_state, raise_exception=True)
+
     teardown()
 
 
-# We use strings for precision and flatten instead of bool to
+# We use strings for precision and flatten params instead of bool to
 # make the pytest output more readable.
 @skip_if_single_gpu
 @pytest.mark.parametrize("precision", ["full", "mixed"])
 @pytest.mark.parametrize("flatten", ["flatten", "no_flatten"])
-def test1(precision, flatten):
+def test1(temp_files, ddp_ref, precision, flatten):
     if torch_version() < (1, 6, 0):
         pytest.skip("older pytorch doesn't support reduce_scatter")
-    temp_file_name = tempfile.mkstemp()[1]
-    unused = tempfile.mkstemp()[1]
+
+    state_before, inputs, state_after_fp, state_after_mp = ddp_ref
+
+    if precision == "full":
+        state_after = state_after_fp
+    else:
+        state_after = state_after_mp
 
     fsdp_config = {}
     fsdp_config["mixed_precision"] = precision == "mixed"
     fsdp_config["flatten_parameters"] = flatten == "flatten"
 
+    # Some bugs only show up when we are in world_size > 1 due to sharding changing
+    # the tensor dimensions.
     world_size = 2
     mp.spawn(
-        _test_func, args=(world_size, fsdp_config, temp_file_name, unused), nprocs=world_size, join=True,
+        _test_func,
+        args=(world_size, fsdp_config, None, temp_files[0], temp_files[1], state_before, inputs, None, state_after),
+        nprocs=world_size,
+        join=True,
     )
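
One detail in _test_func above is worth a standalone note: the DDP mixed-precision branch picks its context manager up front (a no-op `contextlib.suppress()` by default, `torch.cuda.amp.autocast` when mixed precision is on), so the forward pass is written only once. A minimal sketch of that idiom, using a hypothetical `forward_pass` helper and `use_amp` flag that are not part of the commit:

    import contextlib

    import torch


    def forward_pass(model, in_data, use_amp: bool):
        # Default to a do-nothing context; swap in autocast only when requested.
        context = contextlib.suppress()
        if use_amp and torch.cuda.is_available():
            context = torch.cuda.amp.autocast(enabled=True)
        with context:
            return model(in_data)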