Unverified commit 91c7dd05 authored by Yanli Zhao, committed by GitHub


Make sure requires_grad of FlatParameter is consistent with requires_grad of the original parameters (#721)

* Make sure requires_grad of FlatParameter is consistent with requires_grad of the original parameters

* Make sure requires_grad of FlatParameter is consistent with requires_grad of the original parameters
parent e2c39426
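
In short, the change enforces one invariant and forwards one flag: every parameter flattened into a single FlatParameter must agree on requires_grad, and that shared value is passed on to the flat parameter instead of the nn.Parameter default. A minimal, stand-alone restatement in plain PyTorch (FlatParameter itself appears in the first hunk below; this snippet is illustrative, not fairscale code):

```python
# Stand-alone restatement of the invariant added in this commit. The real
# change is `FlatParameter(params, params[0].requires_grad)` in the hunk below.
import torch
import torch.nn as nn


def check_group(params):
    flags = set(p.requires_grad for p in params)
    # All parameters flattened into one FlatParameter must agree on requires_grad.
    assert len(flags) == 1, (
        "expects all parameters in the same parameter group of the module "
        "to have same requires_grad"
    )
    return flags.pop()  # the shared flag that gets forwarded to the flat parameter


frozen = [nn.Parameter(torch.randn(3), requires_grad=False) for _ in range(2)]
print(check_group(frozen))  # False -> forwarded instead of defaulting to True

mixed = [nn.Parameter(torch.randn(3)), nn.Parameter(torch.randn(3), requires_grad=False)]
try:
    check_group(mixed)
except AssertionError:
    print("mixed requires_grad within one group is rejected")
```
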
@@ -182,7 +182,10 @@ class FlattenParamsWrapper(nn.Module):
         # Init all flat_params.
         for new_p_set in self._param_sets:
             params = self._init_flatten_params(new_p_set)
-            flat_param = FlatParameter(params)
+            assert (
+                len(set(p.requires_grad for p in params)) == 1
+            ), "expects all parameters in the same parameter group of the module to have same requires_grad"
+            flat_param = FlatParameter(params, params[0].requires_grad)
             self.flat_params.append(flat_param)
         self._flatten_params(self.flat_params)
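
A hedged usage sketch of the wrapper after this hunk. The import path fairscale.nn.misc.FlattenParamsWrapper and the assumption that all of a module's parameters land in one flat_params group by default are mine, not stated in the diff:

```python
# Hedged sketch, assuming fairscale is installed, FlattenParamsWrapper is
# importable from fairscale.nn.misc, and the default grouping puts all of the
# wrapped module's parameters into a single FlatParameter.
import torch.nn as nn
from fairscale.nn.misc import FlattenParamsWrapper

frozen = nn.Linear(4, 4)
for p in frozen.parameters():
    p.requires_grad = False

wrapped = FlattenParamsWrapper(frozen)
# With this change the flat parameter should report the same requires_grad as
# the frozen originals instead of disagreeing with them.
print([fp.requires_grad for fp in wrapped.flat_params])  # expected: [False]
```
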
@@ -11,6 +11,7 @@
 from enum import Enum
+from itertools import product
 import tempfile
 import pytest
@@ -25,7 +26,7 @@ from fairscale.utils.testing import dist_init, objects_are_equal, rmf, skip_if_s
 class Model(nn.Module):
-    def __init__(self):
+    def __init__(self, with_fsdp, freeze_after_wrap_fsdp):
         super().__init__()
         self.trunk = nn.Sequential(
             nn.Conv2d(3, 64, kernel_size=3),
@@ -34,38 +35,48 @@ class Model(nn.Module):
             nn.Flatten(),
         )
         self.head = nn.Linear(64, 10)
+        if with_fsdp and freeze_after_wrap_fsdp:
+            self.fsdp_wrap()
+
+    def fsdp_wrap(self):
+        self.trunk = FSDP(self.trunk)
+        self.head = FSDP(self.head)

     def forward(self, x):
         return self.head(self.trunk(x))


 class NestedTrunkModel(nn.Module):
-    def __init__(self, with_fsdp):
+    def __init__(self, with_fsdp, freeze_after_wrap_fsdp):
         super().__init__()
-        self.trunk = nn.Sequential(self._create_block(3, 64, with_fsdp), self._create_block(64, 64, with_fsdp),)
+        self.trunk = nn.Sequential(
+            self._create_block(3, 64, with_fsdp, freeze_after_wrap_fsdp),
+            self._create_block(64, 64, with_fsdp, freeze_after_wrap_fsdp),
+        )
         self.head = nn.Sequential(nn.AdaptiveAvgPool2d(output_size=(1, 1)), nn.Flatten(), nn.Linear(64, 10),)
-        if with_fsdp:
-            self.trunk = FSDP(self.trunk)
-            self.head = FSDP(self.head)
+        if with_fsdp and freeze_after_wrap_fsdp:
+            self.fsdp_wrap()
+
+    def fsdp_wrap(self):
+        for name, child in self.trunk.named_children():
+            wrapped_child = FSDP(child)
+            setattr(self.trunk, name, wrapped_child)
+        self.trunk = FSDP(self.trunk)
+        self.head = FSDP(self.head)

     def forward(self, x):
         return self.head(self.trunk(x))

-    def _create_block(self, in_channels, out_channels, with_fsdp):
+    def _create_block(self, in_channels, out_channels, with_fsdp, freeze_after_wrap_fsdp):
         block = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3), nn.ReLU(inplace=True),)
         if with_fsdp:
             block = FSDP(block)
         return block


-def _create_model(with_fsdp, with_nested_trunk):
+def _create_model(with_fsdp, with_nested_trunk, freeze_after_wrap_fsdp):
     if with_nested_trunk:
-        model = NestedTrunkModel(with_fsdp)
+        model = NestedTrunkModel(with_fsdp, freeze_after_wrap_fsdp)
     else:
-        model = Model()
-        if with_fsdp:
-            model.trunk = FSDP(model.trunk)
-            model.head = FSDP(model.head)
+        model = Model(with_fsdp, freeze_after_wrap_fsdp)
     return model
@@ -80,6 +91,7 @@ def _distributed_worker(
     with_fsdp,
     with_nested_trunk,
     freezing_method,
+    freeze_after_wrap_fsdp,
     tempfile_name,
     unused,
     rank_0_output,
@@ -95,7 +107,7 @@ def _distributed_worker(
     torch.backends.cudnn.deterministic = True
     batch = torch.randn(size=(2, 3, 224, 224)).cuda()
-    model = _create_model(with_fsdp, with_nested_trunk)
+    model = _create_model(with_fsdp, with_nested_trunk, freeze_after_wrap_fsdp)
     model = model.cuda()

     # freezing the trunk using requires_grad.
@@ -104,6 +116,8 @@
             param.requires_grad = False

     if with_fsdp:
+        if not freeze_after_wrap_fsdp:
+            model.fsdp_wrap()
         model = FSDP(model)
     else:
         model = DistributedDataParallel(model, device_ids=[gpu_id])
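
The new freeze_after_wrap_fsdp flag therefore selects between two orderings: True wraps the sub-modules with FSDP inside the model constructor and freezes afterwards, while False freezes the raw parameters first and only then calls fsdp_wrap(). A condensed, hypothetical restatement of that worker logic, not additional test code:

```python
# Hypothetical condensation of the ordering logic above; `Model`, `FSDP`, and
# `FreezingMethod` come from the test module, and an initialized distributed
# process group is assumed, exactly as in _distributed_worker.
def build_frozen_model(freezing_method, freeze_after_wrap_fsdp):
    # True: trunk/head were already FSDP-wrapped inside Model.__init__.
    # False: the model is still a plain nn.Module at this point.
    model = Model(True, freeze_after_wrap_fsdp).cuda()
    if freezing_method == FreezingMethod.RequiresGrad:
        for param in model.trunk.parameters():
            param.requires_grad = False
    if not freeze_after_wrap_fsdp:
        # Freeze-then-wrap: the FlatParameters must inherit requires_grad=False,
        # which is what this PR guarantees.
        model.fsdp_wrap()
    return FSDP(model)
```
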
@@ -142,7 +156,7 @@ def _distributed_worker(
 # A fixture to get tempfiles and ensure they are cleaned up.
 @pytest.fixture()
 def temp_files():
-    num = 9  # 1 DDP and 2 FSDP cases each needs 3 files.
+    num = 15  # 1 DDP and 4 FSDP cases each needs 3 files.
     files = [tempfile.mkstemp()[1] for _ in range(num)]
     yield tuple(files)
@@ -163,18 +177,20 @@ def test_freezing_weights(temp_files, nested_trunk):
     freezing_method = FreezingMethod.RequiresGrad
     mp.spawn(
         _distributed_worker,
-        (world_size, with_fsdp, with_nested_trunk, freezing_method) + temp_files[0:3] + (None,),
+        (world_size, with_fsdp, with_nested_trunk, freezing_method, True) + temp_files[0:3] + (None,),
         nprocs=world_size,
     )

     # FSDP, case 1 and 2.
     with_fsdp = True
     expected_state = torch.load(temp_files[2])
     temp_file_idx = 3
-    for freezing_method in [FreezingMethod.RequiresGrad, FreezingMethod.GradToNone]:
+    for freezing_method, freeze_after_wrap_fsdp in product(
+        [FreezingMethod.RequiresGrad, FreezingMethod.GradToNone], [True, False]
+    ):
         print(f"Testing FSDP with freezing method {freezing_method}")
         mp.spawn(
             _distributed_worker,
-            (world_size, with_fsdp, with_nested_trunk, freezing_method)
+            (world_size, with_fsdp, with_nested_trunk, freezing_method, freeze_after_wrap_fsdp)
             + temp_files[temp_file_idx : temp_file_idx + 3]
             + (expected_state,),
             nprocs=world_size,
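
For reference, the new product(...) loop yields four FSDP combinations; together with the single DDP baseline that is five runs at three temp files each, matching num = 15 in the fixture hunk above. A quick check:

```python
# Quick check of the case count behind `num = 15` in the fixture above:
# 4 FSDP combinations + 1 DDP baseline = 5 runs, each using 3 temp files.
from itertools import product

fsdp_cases = list(product(["RequiresGrad", "GradToNone"], [True, False]))
print(len(fsdp_cases))            # 4
print((len(fsdp_cases) + 1) * 3)  # 15
```
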