Unverified commit 73f73120 authored by anj-s, committed by GitHub

[fix] Add an additional assert checking whether the params of a module have requires_grad=True (#761)

* add additional assert for checking if the requires_grad field is set.

* fix lint errors

* add unit tests and address comments
parent a825348d
@@ -1405,7 +1405,11 @@ class FullyShardedDataParallel(nn.Module):
     def _wait_for_post_backward(self) -> None:
         """Wait for post-backward to finish. Only called on root instance."""
         assert self._is_root
-        if self._has_params:
+        # Check if the root module has params and if any of them has
+        # the `requires_grad` field set. If `requires_grad=False` for
+        # all the params, the post_backward hook will not fire and the
+        # state will remain in `TrainingState.BACKWARD_PRE`.
+        if any([p.requires_grad for p in self.params]):
             self.assert_state(TrainingState.BACKWARD_POST)
         else:
             self.assert_state(TrainingState.BACKWARD_PRE)
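For context, the reason the root can legitimately end a backward pass still in `TrainingState.BACKWARD_PRE` is that autograd only computes gradients, and therefore only fires gradient hooks, for parameters with `requires_grad=True`. A minimal standalone sketch of that behaviour, independent of FSDP and not part of this patch (module names are illustrative):

import torch
import torch.nn as nn

frozen = nn.Linear(4, 4)
for p in frozen.parameters():
    p.requires_grad = False

trainable = nn.Linear(4, 4)

fired = []
# Tensor.register_hook is only allowed on tensors that require grad, so a
# backward hook can only ever be attached to the trainable params here.
for p in trainable.parameters():
    p.register_hook(lambda grad: fired.append(True) or grad)

loss = trainable(frozen(torch.randn(2, 4))).sum()
loss.backward()
print(len(fired))  # 2: hooks fired for the trainable weight and bias only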
@@ -1441,7 +1445,11 @@ class FullyShardedDataParallel(nn.Module):
             _remove_shard_bwd_hook(m)
             m._pre_backward_hook_has_run = False
             if any(p.requires_grad for p in m.parameters()):
-                if m._has_params:
+                # Check if the module has params and if any of them has
+                # the `requires_grad` field set. If `requires_grad=False` for
+                # all the params, the post_backward hook will not fire and the
+                # state will remain in `TrainingState.BACKWARD_PRE`.
+                if any([p.requires_grad for p in m.params]):
                     m.assert_state(TrainingState.BACKWARD_POST)
                 else:
                     m.assert_state(TrainingState.BACKWARD_PRE)
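The hunk above keeps both checks because they answer different questions: `m.parameters()` recurses into nested FSDP children, while `m.params` holds only the parameters that `m` manages directly (the ones flattened at this wrapping level). A rough, non-distributed sketch of how the two views can disagree, mirroring the FreezeModel used in the test below; the nested-FSDP trunk is replaced by a plain submodule and nothing here is part of the patch:

import torch.nn as nn

class Root(nn.Module):
    def __init__(self):
        super().__init__()
        self.trunk = nn.Linear(4, 4)  # stands in for the separately wrapped, trainable trunk
        self.head = nn.Linear(4, 2)   # owned directly by the root and frozen below

root = Root()
for p in root.head.parameters():
    p.requires_grad = False

# Analogue of `any(p.requires_grad for p in m.parameters())`: the subtree as a
# whole still contains trainable params.
print(any(p.requires_grad for p in root.parameters()))       # True

# Analogue of `any(p.requires_grad for p in m.params)`: everything the root
# itself owns is frozen, so its own post-backward hook never fires and its
# state stays in BACKWARD_PRE.
print(any(p.requires_grad for p in root.head.parameters()))  # False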
@@ -25,6 +25,73 @@ from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.utils.testing import dist_init, objects_are_equal, rmf, skip_if_single_gpu, teardown
+
+
+class FreezeModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.trunk = nn.Sequential(
+            nn.Conv2d(3, 64, kernel_size=3),
+            nn.ReLU(inplace=True),
+            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
+            nn.Flatten(),
+        )
+        self.head = nn.Linear(64, 10)
+        self.trunk = FSDP(self.trunk)
+
+    def forward(self, x):
+        return self.head(self.trunk(x))
+
+
+def _freeze_distributed_worker(
+    gpu_id, world_size, tempfile_name, unused,
+):
+    torch.cuda.set_device(gpu_id)
+
+    rank = gpu_id
+    result = dist_init(rank, world_size, tempfile_name, unused)
+    assert result, "Dist init failed"
+
+    torch.manual_seed(0)
+    torch.backends.cudnn.deterministic = True
+    batch = torch.randn(size=(2, 3, 224, 224)).cuda()
+
+    # The use case for this test is where the weights in the submodule
+    # are not frozen but the leftover weights or those contained by the
+    # root module are frozen. Refer to issue #758 for a real world example.
+    model = FreezeModel()
+    model = model.cuda()
+
+    for param in model.head.parameters():
+        param.requires_grad = False
+
+    model = FSDP(model)
+
+    if gpu_id == 0:
+        print(model)
+
+    target = torch.tensor([0, 1], dtype=torch.long).cuda()
+    criterion = nn.CrossEntropyLoss()
+    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
+
+    for iteration in range(3):
+        out = model(batch)
+        fake_loss = criterion(out, target)
+        print("Loss", iteration, ":", fake_loss.item())
+        optimizer.zero_grad()
+        fake_loss.backward()
+        optimizer.step()
+
+    teardown()
+
+
+@skip_if_single_gpu
+def test_submodule_freezing_weights(temp_files):
+    world_size = 2
+    mp.spawn(
+        _freeze_distributed_worker, (world_size, temp_files[0], temp_files[1]), nprocs=world_size,
+    )
+
+
+class Model(nn.Module):
+    def __init__(self, with_fsdp, freeze_after_wrap_fsdp):
+        super().__init__()
...