Unverified commit 91c7dd05, authored by Yanli Zhao, committed by GitHub


Make sure requires_grad of FlatParameter to be consistent with requires_grad of original parameters (#721)

parent e2c39426
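
For context, a minimal sketch (plain PyTorch, not fairscale's FlatParameter) of the behavior this commit guards against: nn.Parameter defaults to requires_grad=True, so flattening a frozen submodule's weights into a new parameter silently re-enables gradients unless the original flag is propagated, which is what the FlatParameter constructor now receives.

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
for p in layer.parameters():
    p.requires_grad = False  # the user froze this submodule

flat_data = torch.cat([p.detach().reshape(-1) for p in layer.parameters()])

naive = nn.Parameter(flat_data)            # default: requires_grad=True
print(naive.requires_grad)                 # True -> frozen state lost

consistent = nn.Parameter(flat_data, requires_grad=False)
print(consistent.requires_grad)            # False -> matches the originals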
@@ -182,7 +182,10 @@ class FlattenParamsWrapper(nn.Module):
         # Init all flat_params.
         for new_p_set in self._param_sets:
             params = self._init_flatten_params(new_p_set)
-            flat_param = FlatParameter(params)
+            assert (
+                len(set(p.requires_grad for p in params)) == 1
+            ), "expects all parameters in the same parameter group of the module to have same requires_grad"
+            flat_param = FlatParameter(params, params[0].requires_grad)
             self.flat_params.append(flat_param)
         self._flatten_params(self.flat_params)
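
The assert added above requires every parameter flattened into one FlatParameter to agree on requires_grad. The hypothetical helper below (check_uniform_requires_grad, not part of fairscale) mirrors that check to show what passes and what is rejected:

import torch.nn as nn

def check_uniform_requires_grad(params):
    # Mirrors the new assert: a mixed group cannot be flattened safely.
    flags = {p.requires_grad for p in params}
    assert len(flags) == 1, (
        "expects all parameters in the same parameter group of the module "
        "to have same requires_grad"
    )
    return flags.pop()

block = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
params = list(block.parameters())
print(check_uniform_requires_grad(params))  # True: nothing frozen yet

params[0].requires_grad = False             # freeze only one tensor of the group
try:
    check_uniform_requires_grad(params)
except AssertionError as err:
    print("rejected:", err)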
@@ -11,6 +11,7 @@
 from enum import Enum
+from itertools import product
 import tempfile

 import pytest
@@ -25,7 +26,7 @@ from fairscale.utils.testing import dist_init, objects_are_equal, rmf, skip_if_s
 class Model(nn.Module):
-    def __init__(self):
+    def __init__(self, with_fsdp, freeze_after_wrap_fsdp):
         super().__init__()
         self.trunk = nn.Sequential(
             nn.Conv2d(3, 64, kernel_size=3),
@@ -34,38 +35,48 @@ class Model(nn.Module):
             nn.Flatten(),
         )
         self.head = nn.Linear(64, 10)
+        if with_fsdp and freeze_after_wrap_fsdp:
+            self.fsdp_wrap()
+
+    def fsdp_wrap(self):
+        self.trunk = FSDP(self.trunk)
+        self.head = FSDP(self.head)

     def forward(self, x):
         return self.head(self.trunk(x))


 class NestedTrunkModel(nn.Module):
-    def __init__(self, with_fsdp):
+    def __init__(self, with_fsdp, freeze_after_wrap_fsdp):
         super().__init__()
-        self.trunk = nn.Sequential(self._create_block(3, 64, with_fsdp), self._create_block(64, 64, with_fsdp),)
+        self.trunk = nn.Sequential(
+            self._create_block(3, 64, with_fsdp, freeze_after_wrap_fsdp),
+            self._create_block(64, 64, with_fsdp, freeze_after_wrap_fsdp),
+        )
         self.head = nn.Sequential(nn.AdaptiveAvgPool2d(output_size=(1, 1)), nn.Flatten(), nn.Linear(64, 10),)
-        if with_fsdp:
-            self.trunk = FSDP(self.trunk)
-            self.head = FSDP(self.head)
+        if with_fsdp and freeze_after_wrap_fsdp:
+            self.fsdp_wrap()
+
+    def fsdp_wrap(self):
+        for name, child in self.trunk.named_children():
+            wrapped_child = FSDP(child)
+            setattr(self.trunk, name, wrapped_child)
+        self.trunk = FSDP(self.trunk)
+        self.head = FSDP(self.head)

     def forward(self, x):
         return self.head(self.trunk(x))

-    def _create_block(self, in_channels, out_channels, with_fsdp):
+    def _create_block(self, in_channels, out_channels, with_fsdp, freeze_after_wrap_fsdp):
         block = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3), nn.ReLU(inplace=True),)
-        if with_fsdp:
-            block = FSDP(block)
         return block


-def _create_model(with_fsdp, with_nested_trunk):
+def _create_model(with_fsdp, with_nested_trunk, freeze_after_wrap_fsdp):
     if with_nested_trunk:
-        model = NestedTrunkModel(with_fsdp)
+        model = NestedTrunkModel(with_fsdp, freeze_after_wrap_fsdp)
     else:
-        model = Model()
-        if with_fsdp:
-            model.trunk = FSDP(model.trunk)
-            model.head = FSDP(model.head)
+        model = Model(with_fsdp, freeze_after_wrap_fsdp)
     return model
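
The test models now expose fsdp_wrap() so that wrapping can happen inside __init__ (freeze_after_wrap_fsdp=True) or be deferred until after the trunk is frozen (freeze_after_wrap_fsdp=False). A rough CPU-only sketch of the two orderings, with a no-op wrap() standing in for FSDP:

import torch.nn as nn

def wrap(module):
    return module  # stand-in for FSDP(module); a real run would shard here

class TinyModel(nn.Module):
    def __init__(self, wrap_in_init: bool):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
        self.head = nn.Linear(8, 2)
        if wrap_in_init:
            self.fsdp_wrap()

    def fsdp_wrap(self):
        self.trunk = wrap(self.trunk)
        self.head = wrap(self.head)

    def forward(self, x):
        return self.head(self.trunk(x))

# freeze_after_wrap_fsdp=True: wrap first, then freeze the trunk.
m1 = TinyModel(wrap_in_init=True)
for p in m1.trunk.parameters():
    p.requires_grad = False

# freeze_after_wrap_fsdp=False: freeze first, wrap afterwards.
m2 = TinyModel(wrap_in_init=False)
for p in m2.trunk.parameters():
    p.requires_grad = False
m2.fsdp_wrap()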
@@ -80,6 +91,7 @@ def _distributed_worker(
     with_fsdp,
     with_nested_trunk,
     freezing_method,
+    freeze_after_wrap_fsdp,
     tempfile_name,
     unused,
     rank_0_output,
@@ -95,7 +107,7 @@ def _distributed_worker(
     torch.backends.cudnn.deterministic = True
     batch = torch.randn(size=(2, 3, 224, 224)).cuda()

-    model = _create_model(with_fsdp, with_nested_trunk)
+    model = _create_model(with_fsdp, with_nested_trunk, freeze_after_wrap_fsdp)
     model = model.cuda()

     # freezing the trunk using requires_grad.
@@ -104,6 +116,8 @@ def _distributed_worker(
             param.requires_grad = False

     if with_fsdp:
+        if not freeze_after_wrap_fsdp:
+            model.fsdp_wrap()
         model = FSDP(model)
     else:
         model = DistributedDataParallel(model, device_ids=[gpu_id])
@@ -142,7 +156,7 @@ def _distributed_worker(
 # A fixture to get tempfiles and ensure they are cleaned up.
 @pytest.fixture()
 def temp_files():
-    num = 9  # 1 DDP and 2 FSDP cases each needs 3 files.
+    num = 15  # 1 DDP and 4 FSDP cases each needs 3 files.
     files = [tempfile.mkstemp()[1] for _ in range(num)]
     yield tuple(files)
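
The new count follows from the added test axis: 2 freezing methods × 2 wrap orders give 4 FSDP cases, plus the 1 DDP baseline, and each of the 5 cases writes 3 temp files, so (1 + 4) × 3 = 15.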
@@ -163,18 +177,20 @@ def test_freezing_weights(temp_files, nested_trunk):
     freezing_method = FreezingMethod.RequiresGrad
     mp.spawn(
         _distributed_worker,
-        (world_size, with_fsdp, with_nested_trunk, freezing_method) + temp_files[0:3] + (None,),
+        (world_size, with_fsdp, with_nested_trunk, freezing_method, True) + temp_files[0:3] + (None,),
         nprocs=world_size,
     )

     # FSDP, case 1 and 2.
     with_fsdp = True
     expected_state = torch.load(temp_files[2])
     temp_file_idx = 3
-    for freezing_method in [FreezingMethod.RequiresGrad, FreezingMethod.GradToNone]:
+    for freezing_method, freeze_after_wrap_fsdp in product(
+        [FreezingMethod.RequiresGrad, FreezingMethod.GradToNone], [True, False]
+    ):
         print(f"Testing FSDP with freezing method {freezing_method}")
         mp.spawn(
             _distributed_worker,
-            (world_size, with_fsdp, with_nested_trunk, freezing_method)
+            (world_size, with_fsdp, with_nested_trunk, freezing_method, freeze_after_wrap_fsdp)
             + temp_files[temp_file_idx : temp_file_idx + 3]
             + (expected_state,),
             nprocs=world_size,
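
The itertools.product call enumerates the 2 × 2 FSDP configurations fed to _distributed_worker. A standalone sketch of that enumeration, with plain strings standing in for the FreezingMethod enum members:

from itertools import product

freezing_methods = ["RequiresGrad", "GradToNone"]  # stand-ins for the enum
for freezing_method, freeze_after_wrap_fsdp in product(freezing_methods, [True, False]):
    print(freezing_method, freeze_after_wrap_fsdp)
# RequiresGrad True
# RequiresGrad False
# GradToNone True
# GradToNone False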