"src/vscode:/vscode.git/clone" did not exist on "7101c7316b6f6d3f4e578f29c108533cb678a304"
Unverified commit 21cba91b authored by Min Xu, committed by GitHub

[test] improve BN test coverage (#638)



* [test] improve BN test coverage

- Added sync_bn on/off cases
- Added conv and linear bias on/off cases
- clarified in the test when BN wrapping is needed while sync_bn is off (see the coverage sketch below)

* adding a comment
Co-authored-by: Min Xu <min.xu@acm.org>
parent bdc0581b
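For orientation, the coverage this commit adds boils down to the following sketch (illustrative only, not part of the commit; it mirrors the flag names used in the test diff below):

import random
from itertools import product

precisions = ["full", "mixed"]
sync_bn_modes = ["none", "pytorch"]

# Bias flavors are picked randomly once instead of being parametrized exhaustively,
# to keep the test runtime down (mirrors the ddp_ref fixture below).
conv_bias = random.choice([True, False])
linear_bias = random.choice([True, False])

# Reference DDP states are built for every (precision, sync_bn) pair.
for precision, sync_bn in product(precisions, sync_bn_modes):
    print(f"precision={precision}, sync_bn={sync_bn}, conv_bias={conv_bias}, linear_bias={linear_bias}")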
@@ -10,6 +10,7 @@
""" Test FSDP with regnet-like model. """
import contextlib
from itertools import product
import random
import tempfile
@@ -52,8 +53,8 @@ from fairscale.utils.testing import (
_world_size = 2
_iterations = 5
# Cover different ReLU flavor. This will cause DDP and FSDP models to have
# different ReLUs since they will have different random flags.
# Cover different ReLU flavors. Different workers may have different values since
# this is a file-level global. This is intentional to cover different behaviors.
_relu_inplace = True
if random.randint(0, 1) == 0:
_relu_inplace = False
@@ -66,7 +67,6 @@ try:
except ImportError:
apex_bn_converter = None
pytorch_bn_converter = SyncBatchNorm.convert_sync_batchnorm # type: ignore
_bn_converter = pytorch_bn_converter
_single_rank_pg = False
@@ -109,23 +109,19 @@ class ResBlock(Module):
class Model(Module):
"""SSL model with trunk and head."""
def __init__(self):
def __init__(self, conv_bias, linear_bias):
super().__init__()
print(f"Using relu inplace: {_relu_inplace}")
print(f"relu inplace: {_relu_inplace}, conv bias: {conv_bias}, linear bias: {linear_bias}")
self.trunk = Sequential()
self.trunk.need_fsdp_wrap = True # Set a flag for later wrapping.
stem = Sequential(Conv2d(2, 4, (3, 3), (2, 2), (1, 1), bias=False), BatchNorm2d(4), ReLU(_relu_inplace))
stem = Sequential(Conv2d(2, 4, (3, 3), (2, 2), (1, 1), bias=conv_bias), BatchNorm2d(4), ReLU(_relu_inplace))
any_stage_block1_0 = ResBlock(4, 8)
self.trunk.add_module("stem", stem)
self.trunk.add_module("any_stage_block1", Sequential(any_stage_block1_0))
self.head = Sequential(
# TODO (Min): FSDP-mixed_precision doesn't compute the same way as DDP AMP when bias=True,
# so we use bias=False for now in the projection_head.
# The Conv2d layers above do not use bias in regnet, but even if they use
# bias, FSDP and DDP seem to agree on how it is computed.
Sequential(Linear(16, 16, bias=False), ReLU(), Linear(16, 8, bias=False),), # projection_head
Sequential(Linear(16, 16, bias=linear_bias), ReLU(), Linear(16, 8, bias=linear_bias)), # projection_head
Linear(8, 15, bias=False), # prototypes0
)
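As a quick sanity check of the new constructor flags, something like the following hypothetical snippet (not part of the commit, assuming the Model class above is in scope) confirms that the flags control whether bias parameters exist:

# Hypothetical sanity check of the conv_bias / linear_bias flags.
m = Model(conv_bias=False, linear_bias=True)
assert m.trunk.stem[0].bias is None      # Conv2d built with bias=False
assert m.head[0][0].bias is not None     # projection_head Linear built with bias=True
assert m.head[1].bias is None            # prototypes0 always uses bias=False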
@@ -150,8 +146,17 @@ class Model(Module):
# - model state_dict after training
@pytest.fixture(scope="module")
def ddp_ref():
# Cover different bias flavors. Use random choices instead of parametrizing them to reduce
# the test runtime; otherwise, we would cover all cases exhaustively.
conv_bias = True
if random.randint(0, 1) == 0:
conv_bias = False
linear_bias = True
if random.randint(0, 1) == 0:
linear_bias = False
# Get a reference model state
model = Model()
model = Model(conv_bias, linear_bias)
state_before = model.state_dict()
# Get reference inputs per rank.
@@ -163,8 +168,9 @@ def ddp_ref():
for i in range(iterations):
inputs[rank].append(torch.rand(2, 2, 2, 2))
# Run DDP training twice, fp and mp.
for precision in ["full", "mixed"]:
# Run reference DDP training 4 times, fp and mp, sync_bn or not.
state_after = {}
for precision, sync_bn in product(["full", "mixed"], ["none", "pytorch"]):
temp_file_name = tempfile.mkstemp()[1]
unused = tempfile.mkstemp()[1]
rank_0_output = tempfile.mkstemp()[1]
@@ -183,22 +189,25 @@ def ddp_ref():
inputs,
rank_0_output,
None,
sync_bn,
conv_bias,
linear_bias,
),
nprocs=world_size,
join=True,
)
if precision == "full":
state_after_fp = torch.load(rank_0_output)
else:
state_after_mp = torch.load(rank_0_output)
state_after[(precision, sync_bn)] = torch.load(rank_0_output)
finally:
rmf(temp_file_name)
rmf(unused)
rmf(rank_0_output)
assert state_dict_norm(state_after_fp) != state_dict_norm(state_after_mp)
# Sanity check DDP's final states.
states = list(state_after.values())
for state in states[1:]:
assert state_dict_norm(states[0]) != state_dict_norm(state)
return state_before, inputs, state_after_fp, state_after_mp
return state_before, inputs, conv_bias, linear_bias, state_after
# A fixture to get tempfiles and ensure they are cleaned up.
@@ -226,6 +235,9 @@ def _distributed_worker(
inputs,
rank_0_output,
state_after,
sync_bn,
conv_bias,
linear_bias,
):
torch.backends.cudnn.deterministic = True
@@ -240,7 +252,7 @@ def _distributed_worker(
# To match DDP in AMP -O1, we need fp32 reduce scatter.
fsdp_config["fp32_reduce_scatter"] = True
model = Model()
model = Model(conv_bias, linear_bias)
model.load_state_dict(state_before)
model = model.cuda()
@@ -256,7 +268,8 @@ def _distributed_worker(
scaler = DummyScaler()
if ddp:
model = SyncBatchNorm.convert_sync_batchnorm(model)
if sync_bn == "pytorch":
model = pytorch_bn_converter(model)
model = DDP(model, device_ids=[rank], broadcast_buffers=True)
if ddp_mixed_precision:
scaler = GradScaler()
@@ -264,13 +277,15 @@ def _distributed_worker(
# Note, different ranks may wrap in different orders due to different random
# seeds. But the results should be the same.
if random.randint(0, 1) == 0:
print(f"auto_wrap_bn {fsdp_wrap_bn}, then convert_sync_batchnorm")
print(f"auto_wrap_bn {fsdp_wrap_bn}, then sync_bn {sync_bn}")
if fsdp_wrap_bn:
model = auto_wrap_bn(model, _single_rank_pg)
model = _bn_converter(model)
if sync_bn == "pytorch":
model = pytorch_bn_converter(model)
else:
print(f"convert_sync_batchnorm, then auto_wrap_bn {fsdp_wrap_bn}")
model = _bn_converter(model)
print(f"sync_bn {sync_bn}, then auto_wrap_bn {fsdp_wrap_bn}")
if sync_bn == "pytorch":
model = pytorch_bn_converter(model)
if fsdp_wrap_bn:
model = auto_wrap_bn(model, _single_rank_pg)
model = FSDP(model, **fsdp_config).cuda()
@@ -320,7 +335,10 @@ def _distributed_worker(
dump(state_after)
dump(fsdp_state)
assert objects_are_equal(state_after, fsdp_state, raise_exception=True)
# If sync_bn is used, all ranks should have the same state, so we can compare against the
# rank 0 reference on every rank. Otherwise, only rank 0 compares against the rank 0 reference.
if sync_bn != "none" or rank == 0:
assert objects_are_equal(state_after, fsdp_state, raise_exception=True)
teardown()
@@ -330,32 +348,42 @@ def _distributed_worker(
@skip_if_single_gpu
@pytest.mark.parametrize("precision", ["full", "mixed"])
@pytest.mark.parametrize("flatten", ["flatten", "no_flatten"])
def test_regnet(temp_files, ddp_ref, precision, flatten):
@pytest.mark.parametrize("sync_bn", ["none", "pytorch"])
def test_regnet(temp_files, ddp_ref, precision, flatten, sync_bn):
if torch_version() < (1, 6, 0):
pytest.skip("older pytorch doesn't support reduce_scatter")
state_before, inputs, state_after_fp, state_after_mp = ddp_ref
state_before, inputs, conv_bias, linear_bias, state_after = ddp_ref
if precision == "full":
state_after = state_after_fp
else:
state_after = state_after_mp
state_after = state_after[(precision, sync_bn)]
fsdp_config = {}
fsdp_config["mixed_precision"] = precision == "mixed"
fsdp_config["flatten_parameters"] = flatten == "flatten"
# When linear bias is True, DDP's AMP O1 and FSDP's default AMP O1.5 differ, so
# we force FSDP to use AMP O1 here by setting compute_dtype to float32.
if linear_bias:
fsdp_config["compute_dtype"] = torch.float32
if fsdp_config["mixed_precision"] and torch_cuda_version() < (11, 0):
pytest.skip("Only CUDA 11 is supported with AMP equivalency")
# Wrap BN half of the time in full precision mode.
# Wrap BN half of the time.
wrap_bn = True
if random.randint(0, 1) == 0:
wrap_bn = False
# Always wrap BN in mixed precision mode.
if fsdp_config["mixed_precision"]:
# Except, always wrap BN in mixed precision + sync_bn mode, due to an error with sync_bn wrapping,
# regardless of compute_dtype.
if fsdp_config["mixed_precision"] and sync_bn != "none":
wrap_bn = True
# When BN is not wrapped (i.e. it does not run in full precision), FSDP's compute_dtype needs
# to be fp32 to match DDP (otherwise, numerical errors happen in BN's running_mean/running_var
# buffers).
if fsdp_config["mixed_precision"] and not wrap_bn:
fsdp_config["compute_dtype"] = torch.float32
world_size = _world_size
mp.spawn(
_distributed_worker,
@@ -370,6 +398,9 @@ def test_regnet(temp_files, ddp_ref, precision, flatten):
inputs,
None,
state_after,
sync_bn,
conv_bias,
linear_bias,
),
nprocs=world_size,
join=True,