Unverified commit b54eed1b authored by Min Xu, committed by GitHub

[fix] better assert and better test for frozen weights (#657)



* [fix] better assert and better test for frozen weights

- the precise condition should have been to check m.parameters(), not
  m.params (see the sketch below).
- fixes #643

* add changelog

* using an enum is so much better
Co-authored-by: Min Xu <min.xu@acm.org>
parent 1ae77784
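For context, here is a minimal sketch of the distinction the fix is about, in plain PyTorch with no FSDP or distributed setup (the nesting below is illustrative only): a wrapper's directly-owned parameter list only covers the flat parameters registered on that wrapper, while parameters() recurses into child modules, so only the latter reflects whether any weight anywhere under the wrapper still requires gradients.

import torch.nn as nn

# An outer wrapper that owns no parameters directly; all weights live in a nested child.
outer = nn.Sequential(nn.Sequential(nn.Linear(4, 4)))

print(len(list(outer.parameters(recurse=False))))  # 0 -> only params registered on `outer` itself
print(len(list(outer.parameters())))               # 2 -> recursion also sees the child's weight and bias

# Freezing the child leaves parameters() non-empty, but with requires_grad=False everywhere,
# which is the condition the corrected assert branches on.
for p in outer.parameters():
    p.requires_grad = False
print(any(p.requires_grad for p in outer.parameters()))  # False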
@@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## NEXT - TBD
 ### Fixed
+- FSDP: improved frozen weight support
 - FSDP: workaround AMP autocast cache issue with clear\_autocast\_cache flag
 - setup.py: hide CUDA extensions behind BUILD_CUDA_EXTENSIONS envvar
 - SDP: re-expose the module property ([#647](https://github.com/facebookresearch/fairscale/pull/647))
@@ -1301,15 +1301,19 @@ class FullyShardedDataParallel(nn.Module):
             if isinstance(m, FullyShardedDataParallel):
                 _remove_shard_bwd_hook(m)
                 m._pre_backward_hook_has_run = False
-                if m._has_params:
-                    if any(p.requires_grad for p in m.params):
+                # Note: m.parameters() should not be an empty list. FSDP
+                # wrapping modules without weights is not tested at the moment.
+                if any(p.requires_grad for p in m.parameters()):
+                    if m._has_params:
                         m.assert_state(TrainingState.BACKWARD_POST)
                     else:
-                        # Unlikely case, should only happens if `m` has params but none of the
-                        # params has `requires_grad==True`.
-                        m.assert_state(TrainingState.IDLE)
+                        m.assert_state(TrainingState.BACKWARD_PRE)
                 else:
-                    m.assert_state(TrainingState.BACKWARD_PRE)
+                    # Unlikely case. When m and its children has no params
+                    # with `requires_grad==True`, then m's pre-backward and
+                    # post-backward hooks aren't called by autograd. Therefore,
+                    # it is in IDLE state.
+                    m.assert_state(TrainingState.IDLE)
                 m.training_state = TrainingState.IDLE
     @torch.no_grad()
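The comment added in the else branch above relies on the fact that a module's backward hooks never fire when neither its parameters nor its inputs require gradients. A minimal sketch of that behaviour in plain PyTorch (no FSDP; the hook registration here is illustrative, not the mechanism FSDP itself uses):

import torch
import torch.nn as nn

calls = []
lin = nn.Linear(4, 4)
for p in lin.parameters():
    p.requires_grad = False  # fully frozen module
lin.register_full_backward_hook(lambda mod, grad_in, grad_out: calls.append("fired"))

inp = torch.randn(2, 4)   # a typical data batch, requires_grad=False
out = lin(inp).sum()
print(out.requires_grad)  # False: no autograd graph was built for this module
# Calling out.backward() would raise here, and the backward hook can never run,
# which is why such a wrapper stays in the IDLE state in the fixed assert.
print(calls)              # []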
@@ -10,6 +10,7 @@
 """ Test FSDP with some params frozen. """
+from enum import Enum
 import tempfile
 import pytest
@@ -38,16 +39,51 @@ class Model(nn.Module):
         return self.head(self.trunk(x))
-def _create_model(with_fsdp):
-    model = Model()
-    if with_fsdp:
-        model.trunk = FSDP(model.trunk)
-        model.head = FSDP(model.head)
+class NestedTrunkModel(nn.Module):
+    def __init__(self, with_fsdp):
+        super().__init__()
+        self.trunk = nn.Sequential(self._create_block(3, 64, with_fsdp), self._create_block(64, 64, with_fsdp),)
+        self.head = nn.Sequential(nn.AdaptiveAvgPool2d(output_size=(1, 1)), nn.Flatten(), nn.Linear(64, 10),)
+        if with_fsdp:
+            self.trunk = FSDP(self.trunk)
+            self.head = FSDP(self.head)
+    def forward(self, x):
+        return self.head(self.trunk(x))
+    def _create_block(self, in_channels, out_channels, with_fsdp):
+        block = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3), nn.ReLU(inplace=True),)
+        if with_fsdp:
+            block = FSDP(block)
+        return block
+def _create_model(with_fsdp, with_nested_trunk):
+    if with_nested_trunk:
+        model = NestedTrunkModel(with_fsdp)
+    else:
+        model = Model()
+        if with_fsdp:
+            model.trunk = FSDP(model.trunk)
+            model.head = FSDP(model.head)
     return model
+class FreezingMethod(str, Enum):
+    GradToNone = "grad_to_none"
+    RequiresGrad = "requires_grad"
 def _distributed_worker(
-    gpu_id, world_size, with_fsdp, freezing_method, tempfile_name, unused, rank_0_output, expected_state
+    gpu_id,
+    world_size,
+    with_fsdp,
+    with_nested_trunk,
+    freezing_method,
+    tempfile_name,
+    unused,
+    rank_0_output,
+    expected_state,
 ):
     torch.cuda.set_device(gpu_id)
@@ -59,12 +95,11 @@ def _distributed_worker(
     torch.backends.cudnn.deterministic = True
     batch = torch.randn(size=(2, 3, 224, 224)).cuda()
-    model = _create_model(with_fsdp)
+    model = _create_model(with_fsdp, with_nested_trunk)
     model = model.cuda()
     # freezing the trunk using requires_grad.
-    assert freezing_method in ["requires_grad", "grad_to_none"]
-    if freezing_method == "requires_grad":
+    if freezing_method == FreezingMethod.RequiresGrad:
         for param in model.trunk.parameters():
             param.requires_grad = False
@@ -86,7 +121,7 @@ def _distributed_worker(
         print("Loss", iteration, ":", fake_loss.item())
         optimizer.zero_grad()
         fake_loss.backward()
-        if freezing_method == "grad_to_none":
+        if freezing_method == FreezingMethod.GradToNone:
             for param in model.trunk.parameters():
                 param.grad = None
         optimizer.step()
@@ -118,21 +153,30 @@ def temp_files():
 @skip_if_single_gpu
-def test_freezing_weights(temp_files):
+@pytest.mark.parametrize("nested_trunk", ["nested_trunk", "simple_trunk"])
+def test_freezing_weights(temp_files, nested_trunk):
+    with_nested_trunk = nested_trunk == "nested_trunk"
     world_size = 2
     # DDP
-    fsdp = False
-    freezing_method = "requires_grad"
-    mp.spawn(_distributed_worker, (world_size, fsdp, freezing_method) + temp_files[0:3] + (None,), nprocs=world_size)
+    with_fsdp = False
+    freezing_method = FreezingMethod.RequiresGrad
+    mp.spawn(
+        _distributed_worker,
+        (world_size, with_fsdp, with_nested_trunk, freezing_method) + temp_files[0:3] + (None,),
+        nprocs=world_size,
+    )
     # FSDP, case 1 and 2.
-    fsdp = True
+    with_fsdp = True
     expected_state = torch.load(temp_files[2])
     temp_file_idx = 3
-    for freezing_method in ["requires_grad", "grad_to_none"]:
+    for freezing_method in [FreezingMethod.RequiresGrad, FreezingMethod.GradToNone]:
         print(f"Testing FSDP with freezing method {freezing_method}")
         mp.spawn(
             _distributed_worker,
-            (world_size, fsdp, freezing_method) + temp_files[temp_file_idx : temp_file_idx + 3] + (expected_state,),
+            (world_size, with_fsdp, with_nested_trunk, freezing_method)
+            + temp_files[temp_file_idx : temp_file_idx + 3]
+            + (expected_state,),
             nprocs=world_size,
         )
         temp_file_idx += 3
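For reference, a minimal single-process sketch of the two freezing strategies the parametrized test exercises, without FSDP, DDP, or multiprocessing (the toy model and the freezing_method switch here are illustrative stand-ins, not the test's actual modules):

import torch
import torch.nn as nn

freezing_method = "requires_grad"  # or "grad_to_none"

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))  # stand-ins for trunk and head
trunk = model[0]
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

if freezing_method == "requires_grad":
    # Freeze up front: autograd never produces gradients for the trunk.
    for param in trunk.parameters():
        param.requires_grad = False

loss = model(torch.randn(4, 8)).sum()
optimizer.zero_grad()
loss.backward()

if freezing_method == "grad_to_none":
    # Gradients were computed, but dropping them keeps optimizer.step()
    # from updating the trunk, so the net effect matches freezing.
    for param in trunk.parameters():
        param.grad = None

optimizer.step()

With either method the trunk weights are unchanged after optimizer.step(), which is essentially what the test checks by comparing the FSDP runs against the DDP reference state it saved earlier.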