Unverified Commit 7d7edf6d authored by Anupam Bhatnagar, committed by GitHub

Setup pre-commit github action and apply pre-commit to all files (#849)

* adding pre-commit files

* applying pre-commit to all files

* adding no-strict-optional argument to mypy in circle ci config

* fix typo

* updating python versions

* [skip ci] remove extra args

* adding python 3.9

* [skip ci] set pre-commit version in requirements-dev.txt

* set CACHE_VERSION

* move linters from circleci to github actions

* update python version

* update python version in benchmarks_2

* moving to python 3.9.7
parent 6f3931a4
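The "adding pre-commit files" and "set pre-commit version in requirements-dev.txt" steps are not shown in the hunks below, which only contain the reformatted test code. As a point of reference, a minimal sketch of what a `.pre-commit-config.yaml` for a project like this typically contains is given here; the hook repos, ids, and pinned revisions are illustrative assumptions, not the exact file added in this PR:

```yaml
# Illustrative sketch only: hooks and revs are assumptions,
# not the exact configuration committed in this PR.
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.0.1
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
  - repo: https://github.com/psf/black
    rev: 21.9b0
    hooks:
      - id: black
  - repo: https://github.com/pycqa/isort
    rev: 5.9.3
    hooks:
      - id: isort
```

Running `pre-commit run --all-files` once against such a config produces exactly the kind of mechanical reflow visible in the diff below: long call sites broken onto one argument per line and docstring spacing normalized.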
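The "Setup pre-commit github action" and "move linters from circleci to github actions" steps amount to a small workflow file. A minimal sketch is shown below, assuming the standard checkout and setup-python actions and the Python 3.9.7 mentioned in the commit message; the file name, job layout, and action versions are assumptions rather than the PR's actual workflow:

```yaml
# .github/workflows/pre-commit.yml -- hypothetical name and wiring.
name: pre-commit
on: [push, pull_request]
jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - uses: actions/setup-python@v2
        with:
          python-version: "3.9.7"
      - name: Run pre-commit on all files
        run: |
          pip install pre-commit
          pre-commit run --all-files
```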
@@ -263,11 +263,19 @@ def test_deprecated_path():
     # from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper
     from fairscale.nn import checkpoint_wrapper
-    ffn = nn.Sequential(nn.Linear(32, 128), nn.Dropout(p=0.5), nn.Linear(128, 32),)
+    ffn = nn.Sequential(
+        nn.Linear(32, 128),
+        nn.Dropout(p=0.5),
+        nn.Linear(128, 32),
+    )
     ffn = checkpoint_wrapper(ffn, {})
     # Check if direct import works as before.
-    ffn = nn.Sequential(nn.Linear(32, 128), nn.Dropout(p=0.5), nn.Linear(128, 32),)
+    ffn = nn.Sequential(
+        nn.Linear(32, 128),
+        nn.Dropout(p=0.5),
+        nn.Linear(128, 32),
+    )
     ffn = deprecated_checkpoint_wrapper(ffn, {})
...
@@ -83,7 +83,16 @@ class DistributedTest(unittest.TestCase):
     @classmethod
     def _test_identical_outputs(
-        cls, model_init_fn, config, rank, group, num_steps=2, use_cuda=True, lr=0.01, ref_ddp_fn=None, norm_type=2,
+        cls,
+        model_init_fn,
+        config,
+        rank,
+        group,
+        num_steps=2,
+        use_cuda=True,
+        lr=0.01,
+        ref_ddp_fn=None,
+        norm_type=2,
     ):
         if config.get("mixed_precision", False):
             autocast = True
@@ -265,7 +274,10 @@ CONFIG_OPTIONS = [[dict(zip(keys, config))] for config in itertools.product([Tru
 def rename_test(testcase_func, param_num, param):
-    return "%s_%s" % (testcase_func.__name__, parameterized.to_safe_name(str(param.args)),)
+    return "%s_%s" % (
+        testcase_func.__name__,
+        parameterized.to_safe_name(str(param.args)),
+    )
 class TestComparisonToPyTorchDDP(DistributedTest):
@@ -373,7 +385,11 @@ class TestComparisonToPyTorchDDP(DistributedTest):
     def test_mixture_of_experts_grad_clip_breaks(self):
         config = {"mixed_precision": True}
         test_fn = functools.partial(
-            self._test_identical_outputs, MixtureOfExperts, config, ref_ddp_fn=self._dummy_ddp_fn, norm_type=2,
+            self._test_identical_outputs,
+            MixtureOfExperts,
+            config,
+            ref_ddp_fn=self._dummy_ddp_fn,
+            norm_type=2,
         )
         with self.assertRaises(Exception):
             spawn_and_init(test_fn)
@@ -386,7 +402,10 @@ class TestComparisonToPyTorchDDP(DistributedTest):
     def test_clip_norm_transformer(self, norm_type):
         config = {"mixed_precision": True}
         test_fn = functools.partial(
-            self._test_identical_outputs, TransformerWithSharedParams, config, norm_type=norm_type,
+            self._test_identical_outputs,
+            TransformerWithSharedParams,
+            config,
+            norm_type=norm_type,
         )
         spawn_and_init(test_fn)
@@ -593,7 +612,11 @@ class TransformerWithSharedParams(nn.Module):
         assert d_vocab >= 12  # we use torch.arange(12) as input
         self.embed_tokens = nn.Embedding(d_vocab, d_model)
         self.transformer = nn.Transformer(
-            d_model=d_model, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=8, dropout=0.1,
+            d_model=d_model,
+            num_encoder_layers=2,
+            num_decoder_layers=2,
+            dim_feedforward=8,
+            dropout=0.1,
         )
         self.output_proj = nn.Linear(d_model, d_vocab)
@@ -642,7 +665,12 @@ class NestedWrappedModule(nn.Module):
         torch.manual_seed(0)  # keep everything deterministic
         self.module = nn.Sequential(
             nn.Linear(8, 4),
-            _maybe_wrap(nn.Sequential(_maybe_wrap(nn.Linear(4, 16)), nn.Linear(16, 16),)),
+            _maybe_wrap(
+                nn.Sequential(
+                    _maybe_wrap(nn.Linear(4, 16)),
+                    nn.Linear(16, 16),
+                )
+            ),
             _maybe_wrap(nn.Linear(16, 4)),
             nn.Linear(4, 8),
         )
...
@@ -43,7 +43,10 @@ class FreezeModel(nn.Module):
 def _freeze_distributed_worker(
-    gpu_id, world_size, tempfile_name, unused,
+    gpu_id,
+    world_size,
+    tempfile_name,
+    unused,
 ):
     torch.cuda.set_device(gpu_id)
@@ -88,7 +91,9 @@ def _freeze_distributed_worker(
 def test_submodule_freezing_weights(temp_files):
     world_size = 2
     mp.spawn(
-        _freeze_distributed_worker, (world_size, temp_files[0], temp_files[1]), nprocs=world_size,
+        _freeze_distributed_worker,
+        (world_size, temp_files[0], temp_files[1]),
+        nprocs=world_size,
     )
@@ -120,7 +125,11 @@ class NestedTrunkModel(nn.Module):
             self._create_block(3, 64, with_fsdp, freeze_after_wrap_fsdp),
             self._create_block(64, 64, with_fsdp, freeze_after_wrap_fsdp),
         )
-        self.head = nn.Sequential(nn.AdaptiveAvgPool2d(output_size=(1, 1)), nn.Flatten(), nn.Linear(64, 10),)
+        self.head = nn.Sequential(
+            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
+            nn.Flatten(),
+            nn.Linear(64, 10),
+        )
         if with_fsdp and freeze_after_wrap_fsdp:
             self.fsdp_wrap()
@@ -135,7 +144,10 @@ class NestedTrunkModel(nn.Module):
         return self.head(self.trunk(x))
     def _create_block(self, in_channels, out_channels, with_fsdp, freeze_after_wrap_fsdp):
-        block = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3), nn.ReLU(inplace=True),)
+        block = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=3),
+            nn.ReLU(inplace=True),
+        )
         return block
...
@@ -38,7 +38,8 @@ def temp_files():
 # We only test on GPU since mix-precision only works on GPU.
 @skip_if_no_cuda
 @pytest.mark.parametrize(
-    "fsdp_config", [{}, {"mixed_precision": True}],
+    "fsdp_config",
+    [{}, {"mixed_precision": True}],
 )
 @pytest.mark.parametrize("input_cls", [dict, list])
 def test_input_type(temp_files, fsdp_config, input_cls):
...
@@ -43,7 +43,9 @@ class ConvolutionalModel(nn.Module):
     @staticmethod
     def _conv_block(in_channels: int, out_channels: int):
         return nn.Sequential(
-            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3)), nn.BatchNorm2d(out_channels), nn.ReLU(),
+            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3)),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(),
         )
     def forward(self, x):
@@ -78,7 +80,10 @@ def _worker(gpu_id: int, sync_file: str, world_size: int, embedding_size: int, f
     torch.manual_seed(0)
     torch.cuda.set_device(gpu_id)
     torch.distributed.init_process_group(
-        backend="nccl", init_method=f"file://{sync_file}", world_size=world_size, rank=gpu_id,
+        backend="nccl",
+        init_method=f"file://{sync_file}",
+        world_size=world_size,
+        rank=gpu_id,
     )
     process_group = torch.distributed.new_group()
@@ -116,7 +121,8 @@ def _worker(gpu_id: int, sync_file: str, world_size: int, embedding_size: int, f
     # Reconstruct a full checkpoint from the sharded checkpoints
     all_checkpoints = [_load_sharded_checkpoint(rank) for rank in range(world_size)]
     consolidated_checkpoint = FullyShardedDataParallel.consolidate_shard_weights(
-        shard_weights=[c["weights"] for c in all_checkpoints], shard_metadata=[c["meta"] for c in all_checkpoints],
+        shard_weights=[c["weights"] for c in all_checkpoints],
+        shard_metadata=[c["meta"] for c in all_checkpoints],
     )
     # Check that the reconstructed parameters are correct and of the right shape
@@ -159,7 +165,8 @@ def test_consolidation(embedding_size: int, flatten_parameters: bool):
 @skip_if_single_gpu
 class TestConsolidatedWeights(DistributedTest):
     @parameterized.expand(
-        [[True], [False]], name_func=rename_test,
+        [[True], [False]],
+        name_func=rename_test,
     )
     def test_consolidate_weights(self, transformer):
         config = {"mixed_precision": True, "flatten_parameters": True, "compute_dtype": torch.float32}
@@ -186,7 +193,10 @@ class TestConsolidatedWeights(DistributedTest):
         else:
             fsdp = FullyShardedDataParallel(MixtureOfExperts(group, wrapper_config=config)).cuda()
-        optim = Adam(fsdp.parameters(), lr=0.01,)
+        optim = Adam(
+            fsdp.parameters(),
+            lr=0.01,
+        )
         optim.zero_grad()
         with torch.cuda.amp.autocast(enabled=True):
             x = fsdp.module.get_input(torch.device("cuda"))
@@ -207,7 +217,8 @@ class TestConsolidatedWeights(DistributedTest):
             return
         all_checkpoints = [torch.load(p) for p in paths]
         consolidated_checkpoint = FullyShardedDataParallel.consolidate_shard_weights(
-            shard_weights=[c["weights"] for c in all_checkpoints], shard_metadata=[c["meta"] for c in all_checkpoints],
+            shard_weights=[c["weights"] for c in all_checkpoints],
+            shard_metadata=[c["meta"] for c in all_checkpoints],
         )
         full_model_extra = set(full_model_state_dict).difference(set(consolidated_checkpoint))
         consolidated_extra = set(consolidated_checkpoint).difference(set(full_model_state_dict))
...
@@ -76,5 +76,8 @@ def test1(precision, flatten):
     # the tensor dimensions.
     world_size = 2
     mp.spawn(
-        _test_func, args=(world_size, fsdp_config, temp_file_name, unused), nprocs=world_size, join=True,
+        _test_func,
+        args=(world_size, fsdp_config, temp_file_name, unused),
+        nprocs=world_size,
+        join=True,
    )
@@ -247,7 +247,7 @@ def _get_cached_results(
     fp32_reduce_scatter,
     bucket_cap_mb,
 ):
-    """ Cache the training to save time. For DDP, flatten, wrap_bn etc. doesn't matter, so
+    """Cache the training to save time. For DDP, flatten, wrap_bn etc. doesn't matter, so
     the results can be cached.
     """
     if not with_fsdp:
...
@@ -32,7 +32,9 @@ def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
     class InnerModel(Module):
         def __init__(self):
             super().__init__()
-            self.layers = Sequential(FSDP(Linear(5, 5), **fsdp_config),)
+            self.layers = Sequential(
+                FSDP(Linear(5, 5), **fsdp_config),
+            )
         def forward(self, x):
             return self.layers(x)
@@ -85,5 +87,8 @@ def test(world_size, precision, flatten):
     }
     mp.spawn(
-        _test_func, args=(world_size, fsdp_config, temp_file_name, unused), nprocs=world_size, join=True,
+        _test_func,
+        args=(world_size, fsdp_config, temp_file_name, unused),
+        nprocs=world_size,
+        join=True,
     )
@@ -67,7 +67,15 @@ class DistributedTest(unittest.TestCase):
     @classmethod
     def _test_identical_outputs_eval(
-        cls, model_init_fn, config, rank, group, num_steps=2, use_cuda=True, lr=0.01, ref_ddp_fn=None,
+        cls,
+        model_init_fn,
+        config,
+        rank,
+        group,
+        num_steps=2,
+        use_cuda=True,
+        lr=0.01,
+        ref_ddp_fn=None,
     ):
         if config.get("mixed_precision", False):
             autocast = True
@@ -116,7 +124,10 @@ CONFIG_OPTIONS = [[dict(zip(keys, config))] for config in itertools.product([Tru
 def rename_test(testcase_func, param_num, param):
-    return "%s_%s" % (testcase_func.__name__, parameterized.to_safe_name(str(param.args)),)
+    return "%s_%s" % (
+        testcase_func.__name__,
+        parameterized.to_safe_name(str(param.args)),
+    )
 class TestSsdMemory(DistributedTest):
@@ -284,7 +295,11 @@ class TransformerWithSharedParams(nn.Module):
         assert d_vocab >= 12  # we use torch.arange(12) as input
         self.embed_tokens = nn.Embedding(d_vocab, d_model)
         self.transformer = nn.Transformer(
-            d_model=d_model, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=8, dropout=0.1,
+            d_model=d_model,
+            num_encoder_layers=2,
+            num_decoder_layers=2,
+            dim_feedforward=8,
+            dropout=0.1,
         )
         self.output_proj = nn.Linear(d_model, d_vocab)
@@ -333,7 +348,12 @@ class NestedWrappedModule(nn.Module):
         torch.manual_seed(0)  # keep everything deterministic
         self.module = nn.Sequential(
             nn.Linear(8, 4),
-            _maybe_wrap(nn.Sequential(_maybe_wrap(nn.Linear(4, 16)), nn.Linear(16, 16),)),
+            _maybe_wrap(
+                nn.Sequential(
+                    _maybe_wrap(nn.Linear(4, 16)),
+                    nn.Linear(16, 16),
+                )
+            ),
             _maybe_wrap(nn.Linear(16, 4)),
             nn.Linear(4, 8),
         )
...
@@ -62,7 +62,10 @@ class TestOptimizerUtils(DistributedTest):
             fsdp = FullyShardedDataParallel(MixtureOfExperts(group, wrapper_config=config)).cuda()
         try:
-            fsdp_optim = optim_fn(fsdp.parameters(), lr=0.01,)
+            fsdp_optim = optim_fn(
+                fsdp.parameters(),
+                lr=0.01,
+            )
             optim_unwrapped = optim_fn(unwrapped_model.parameters(), lr=0.01)
         except TypeError:  # Adadelta
             fsdp_optim = optim_fn(fsdp.parameters())
...
@@ -82,7 +82,11 @@ class Min10:
 def _distributed_worker(
-    gpu_id, world_size, fsdp_config, tempfile, tempfile_rpc,
+    gpu_id,
+    world_size,
+    fsdp_config,
+    tempfile,
+    tempfile_rpc,
 ):
     torch.cuda.set_device(gpu_id)
@@ -252,5 +256,7 @@ def test_forward_overlap(world_size, flatten, mixed):
     }
     with temp_files_ctx(2) as temp_files:
         mp.spawn(
-            _distributed_worker, (world_size, fsdp_config, temp_files[0], temp_files[1]), nprocs=world_size,
+            _distributed_worker,
+            (world_size, fsdp_config, temp_files[0], temp_files[1]),
+            nprocs=world_size,
         )
@@ -79,7 +79,9 @@ class ResBlock(Module):
         self.bn = BatchNorm2d(width_out)
         self.f = Sequential(
             Sequential(  # block a
-                Conv2d(width_in, width_out, (1, 1), (1, 1), bias=False), BatchNorm2d(width_out), ReLU(_relu_inplace),
+                Conv2d(width_in, width_out, (1, 1), (1, 1), bias=False),
+                BatchNorm2d(width_out),
+                ReLU(_relu_inplace),
             ),
             Sequential(  # block b
                 Conv2d(width_out, width_out, (3, 3), (2, 2), (1, 1), groups=2, bias=False),
...
@@ -93,7 +93,9 @@ def test_shared_weight(temp_files, outer_flat, inner_flat, sharing):
     # Run FSDP
     mp.spawn(
-        _dist_worker, (world_size, temp_files, outer_flat, inner_flat, sharing), nprocs=world_size,
+        _dist_worker,
+        (world_size, temp_files, outer_flat, inner_flat, sharing),
+        nprocs=world_size,
     )
...
@@ -109,7 +109,9 @@ def test_shared_weight_mevo(temp_files, wrap_middle):
     # Run FSDP
     mp.spawn(
-        _dist_worker, (world_size, temp_files, wrap_middle), nprocs=world_size,
+        _dist_worker,
+        (world_size, temp_files, wrap_middle),
+        nprocs=world_size,
     )
...
@@ -80,7 +80,8 @@ def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test
 @skip_if_single_gpu
 @pytest.mark.parametrize("test_case", [{"inputs": [torch.rand(8, 3)], "assert_ref_out": True}])
 @pytest.mark.parametrize(
-    "fsdp_config", [{}, {"flatten_parameters": False}, {"mixed_precision": True}],
+    "fsdp_config",
+    [{}, {"flatten_parameters": False}, {"mixed_precision": True}],
 )
 @pytest.mark.parametrize("world_size", list(range(2, 9)))
 def test_one_iteration(world_size, test_case, fsdp_config):
...
@@ -35,8 +35,7 @@ def _get_mlp(tripwire: bool = False):
         return Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
     class Tripwire(torch.nn.Module):
-        """A model made to expose possible corner cases
-        """
+        """A model made to expose possible corner cases"""
         def __init__(self) -> None:
             super().__init__()
@@ -323,7 +322,10 @@ def test_train_eval_change():
     world_size = 4
     with temp_files_ctx(num=1) as temp_files:
         mp.spawn(
-            run_test_train_eval_change, args=(world_size, temp_files[0]), nprocs=world_size, join=True,
+            run_test_train_eval_change,
+            args=(world_size, temp_files[0]),
+            nprocs=world_size,
+            join=True,
         )
...
@@ -169,7 +169,11 @@ def run_ddp_parity(
             def sharded_closure(input_tensor=input_tensor):
                 return closure(
-                    sharded_ddp_model, sharded_scaler, input_tensor, grad_accumulation, _manual_reduction=manual_reduction,
+                    sharded_ddp_model,
+                    sharded_scaler,
+                    input_tensor,
+                    grad_accumulation,
+                    _manual_reduction=manual_reduction,
                 )
             # Step/scale both
@@ -329,7 +333,9 @@ def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name, reduce_b
         torch.cuda.synchronize(device)
         check_same_model_params(
-            sharded_ddp_model, ddp_model, f"DDP parity two optim test failing, step {i}, buffers {reduce_buffer_size}",
+            sharded_ddp_model,
+            ddp_model,
+            f"DDP parity two optim test failing, step {i}, buffers {reduce_buffer_size}",
         )
     dist.destroy_process_group()
...
@@ -15,7 +15,7 @@ from fairscale.utils.testing import objects_are_equal
 class TestFlattenParams(unittest.TestCase):
-    """ Base test class and used for CPU case. """
+    """Base test class and used for CPU case."""
     def _get_module_init_fns(self):
         return [
@@ -42,7 +42,11 @@ class TestFlattenParams(unittest.TestCase):
     def _get_transformer(self, seed=0):
         torch.manual_seed(seed)  # keep everything deterministic
         module = torch.nn.Transformer(
-            d_model=32, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=128, dropout=0.1,
+            d_model=32,
+            num_encoder_layers=2,
+            num_decoder_layers=2,
+            dim_feedforward=128,
+            dropout=0.1,
         )
         module.register_buffer("dummy_buffer", torch.tensor(1.0))
...
@@ -161,5 +161,10 @@ def test_double_stash_pop_but_isolated():
     ns2 = Namespace()
     verify_skippables(
-        nn.Sequential(Layer1().isolate(ns1), Layer2().isolate(ns1), Layer3().isolate(ns2), Layer4().isolate(ns2),)
+        nn.Sequential(
+            Layer1().isolate(ns1),
+            Layer2().isolate(ns1),
+            Layer3().isolate(ns2),
+            Layer4().isolate(ns2),
+        )
     )
@@ -97,5 +97,10 @@ def test_correctness(use_fp16, checkpoint, chunks):
     model = _get_model()
     rmodel, ropt, rloss = _train_reg_model(model)
-    pmodel, popt, ploss = _train_pipe_model(model, use_fp16=use_fp16, checkpoint=checkpoint, chunks=chunks,)
+    pmodel, popt, ploss = _train_pipe_model(
+        model,
+        use_fp16=use_fp16,
+        checkpoint=checkpoint,
+        chunks=chunks,
+    )
     _check_parity(rmodel, pmodel, ropt, popt, rloss, ploss)