Unverified Commit 961df76e authored by Min Xu's avatar Min Xu Committed by GitHub
Browse files

[fix] mypy and flaky test (#624)



* [fix] mypy and flaky test

- CI didn't seem to catch this or maybe I merged incorrectly yesterday
- this should fix the mypy error on master
- also updated a test that seems to be flaky due to tcp port conflict

* another flaky test, hopefully more determinism helps

* CR

* skip 1.6

* fix

* minor
Co-authored-by: Min Xu <min.xu@acm.org>
parent 85962b97
......@@ -1662,9 +1662,13 @@ def auto_wrap_bn(module: nn.Module, single_rank_pg: bool = False, process_group:
def wrap_bn_only_policy(module: nn.Module, recurse: bool, unwrapped_params: int) -> bool:
    """Auto-wrap policy that wraps only BatchNorm modules.

    Args:
        module: the module currently being considered for wrapping.
        recurse: True while the auto-wrapper is still descending into children;
            False when deciding whether to wrap ``module`` itself.
        unwrapped_params: unused here; present to satisfy the auto-wrap
            policy call signature.

    Returns:
        While recursing, True unless ``module`` is a forced-leaf module
        (so we keep descending into everything else). At the leaf decision,
        True only when ``module`` is a BatchNorm and not an excluded type.
    """
    is_bn = isinstance(module, torch.nn.modules.batchnorm._BatchNorm)
    if recurse:
        return not isinstance(
            module, tuple(default_auto_wrap_policy.FORCE_LEAF_MODULES)  # type: ignore
        )
    else:
        return is_bn and not isinstance(
            module, tuple(default_auto_wrap_policy.EXCLUDE_WRAP_MODULES)  # type: ignore
        )
pg = None
if single_rank_pg:
......
......@@ -9,8 +9,7 @@
""" Test FSDP with different input types. """
import os
import random
import tempfile
import pytest
import torch
......@@ -19,26 +18,40 @@ from torch.optim import SGD
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import TrainingState
from fairscale.utils.testing import skip_if_no_cuda, torch_version
from fairscale.utils.testing import dist_init, rmf, skip_if_no_cuda, teardown, torch_version
# We only test on GPU since mix-precision only really works on GPU.
# A fixture providing temp files and guaranteeing their cleanup.
@pytest.fixture()
def temp_files():
    """Yield a pair of temp file names and remove them after the test."""
    file_count = 2  # dist_init needs 2 files
    names = [tempfile.mkstemp()[1] for _ in range(file_count)]
    yield tuple(names)
    # The test may already have deleted them, hence rmf rather than os.remove.
    for name in names:
        rmf(name)
# We only test on GPU since mixed precision only works on GPU.
@skip_if_no_cuda
@pytest.mark.parametrize(
"fsdp_config", [{}, {"mixed_precision": True}],
)
@pytest.mark.parametrize("input_cls", [dict, list])
def test_it(fsdp_config, input_cls):
def test_input_type(temp_files, fsdp_config, input_cls):
"""Test FSDP with input being a list or a dict, only single GPU."""
if torch_version() < (1, 6, 0):
pytest.skip("older pytorch doesn't support reduce_scatter")
# Random port in case the next test run quickly, same port would cause conflict.
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(random.randint(2000, 3000))
torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)
if torch_version() < (1, 7, 0):
# This test runs multiple test cases in a single process. On 1.6.0 it
# throw an error like this:
# RuntimeError: Container is already initialized! Cannot initialize it twice!
pytest.skip("older pytorch doesn't work well with single process dist_init multiple times")
result = dist_init(rank=0, world_size=1, filename=temp_files[0], filename_rpc=temp_files[1])
assert result, "Dist init failed"
try:
assert isinstance(fsdp_config, dict), str(fsdp_config)
class Model(Module):
......@@ -73,8 +86,4 @@ def test_it(fsdp_config, input_cls):
model.assert_state(TrainingState.IDLE)
finally:
# Clean-up is important or the next test in this file may fail to init the PG.
torch.distributed.destroy_process_group()
del os.environ["MASTER_ADDR"]
del os.environ["MASTER_PORT"]
teardown()
......@@ -169,9 +169,9 @@ def ddp_ref():
unused = tempfile.mkstemp()[1]
rank_0_output = tempfile.mkstemp()[1]
try:
fsdp_config = None # This means we use DDP in _test_func.
fsdp_config = None # This means we use DDP in _distributed_worker.
mp.spawn(
_test_func,
_distributed_worker,
args=(
world_size,
fsdp_config,
......@@ -214,7 +214,7 @@ def temp_files():
rmf(unused)
def _test_func(
def _distributed_worker(
rank,
world_size,
fsdp_config,
......@@ -227,6 +227,8 @@ def _test_func(
rank_0_output,
state_after,
):
torch.backends.cudnn.deterministic = True
result = dist_init(rank, world_size, tempfile_name, unused)
assert result, "Dist init failed"
......@@ -328,7 +330,7 @@ def _test_func(
@skip_if_single_gpu
@pytest.mark.parametrize("precision", ["full", "mixed"])
@pytest.mark.parametrize("flatten", ["flatten", "no_flatten"])
def test1(temp_files, ddp_ref, precision, flatten):
def test_regnet(temp_files, ddp_ref, precision, flatten):
if torch_version() < (1, 6, 0):
pytest.skip("older pytorch doesn't support reduce_scatter")
......@@ -356,7 +358,7 @@ def test1(temp_files, ddp_ref, precision, flatten):
world_size = _world_size
mp.spawn(
_test_func,
_distributed_worker,
args=(
world_size,
fsdp_config,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment