Unverified Commit b54eed1b authored by Min Xu, committed by GitHub

[fix] better assert and better test for frozen weights (#657)



* [fix] better assert and better test for frozen weights

- the precise condition should have been to check m.parameters(), not m.params.
- fixes #643

* add changelog

* using an enum is much better
Co-authored-by: Min Xu <min.xu@acm.org>
parent 1ae77784
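
Background on the fix: a FairScale FSDP wrapper keeps the parameters it manages directly in `m.params` (which is what `m._has_params` reflects), while `nn.Module.parameters()` also recurses into wrapped children. With nested wrapping, an outer FSDP instance whose weights all live in inner FSDP children has an empty `m.params` even though `m.parameters()` is not empty, which is the situation from #643 that tripped the old assert. A minimal sketch of that distinction using plain `nn.Module`s (the `Outer`/`Inner` classes are hypothetical stand-ins, not part of the patch):

    import torch.nn as nn

    class Inner(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4)  # the only registered weights

    class Outer(nn.Module):
        def __init__(self):
            super().__init__()
            self.child = Inner()  # Outer registers no parameters of its own

    m = Outer()
    # Direct parameters only (roughly what m.params / m._has_params reflects): empty.
    print(list(m.parameters(recurse=False)))  # []
    # Recursive view (what the fixed check uses): includes the child's weight and bias.
    print(len(list(m.parameters())))  # 2
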
@@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## NEXT - TBD
 ### Fixed
+- FSDP: improved frozen weight support
 - FSDP: workaround AMP autocast cache issue with clear\_autocast\_cache flag
 - setup.py: hide CUDA extensions behind BUILD_CUDA_EXTENSIONS envvar
 - SDP: re-expose the module property ([#647](https://github.com/facebookresearch/fairscale/pull/647))
@@ -1301,15 +1301,19 @@ class FullyShardedDataParallel(nn.Module):
             if isinstance(m, FullyShardedDataParallel):
                 _remove_shard_bwd_hook(m)
                 m._pre_backward_hook_has_run = False
-                if m._has_params:
-                    if any(p.requires_grad for p in m.params):
+                # Note: m.parameters() should not be an empty list. FSDP
+                # wrapping modules without weights is not tested at the moment.
+                if any(p.requires_grad for p in m.parameters()):
+                    if m._has_params:
                         m.assert_state(TrainingState.BACKWARD_POST)
                     else:
-                        # Unlikely case, should only happens if `m` has params but none of the
-                        # params has `requires_grad==True`.
-                        m.assert_state(TrainingState.IDLE)
+                        m.assert_state(TrainingState.BACKWARD_PRE)
                 else:
-                    m.assert_state(TrainingState.BACKWARD_PRE)
+                    # Unlikely case. When m and its children have no params
+                    # with `requires_grad==True`, m's pre-backward and
+                    # post-backward hooks aren't called by autograd, so it
+                    # stays in the IDLE state.
+                    m.assert_state(TrainingState.IDLE)
                 m.training_state = TrainingState.IDLE

     @torch.no_grad()
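
The reordered branches follow how autograd treats frozen weights: when none of `m`'s parameters, including those of its children, require grad, autograd never triggers FSDP's pre- or post-backward hooks for it, so the module legitimately remains IDLE; when some do, a wrapper that directly owns parameters should have reached BACKWARD_POST, while a parameter-less wrapper is expected to still be in BACKWARD_PRE. A rough standalone illustration of the fully-frozen case, using plain PyTorch rather than FSDP:

    import torch
    import torch.nn as nn

    frozen = nn.Linear(4, 4)
    for p in frozen.parameters():
        p.requires_grad = False
    head = nn.Linear(4, 1)

    loss = head(frozen(torch.randn(2, 4))).sum()
    loss.backward()

    # The frozen layer is never visited during backward, so its grads stay None.
    print([p.grad for p in frozen.parameters()])  # [None, None]
    print(head.weight.grad is not None)           # True
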
@@ -10,6 +10,7 @@
 """ Test FSDP with some params frozen. """
+from enum import Enum
 import tempfile
 import pytest
@@ -38,16 +39,51 @@ class Model(nn.Module):
         return self.head(self.trunk(x))


-def _create_model(with_fsdp):
-    model = Model()
-    if with_fsdp:
-        model.trunk = FSDP(model.trunk)
-        model.head = FSDP(model.head)
+class NestedTrunkModel(nn.Module):
+    def __init__(self, with_fsdp):
+        super().__init__()
+        self.trunk = nn.Sequential(self._create_block(3, 64, with_fsdp), self._create_block(64, 64, with_fsdp),)
+        self.head = nn.Sequential(nn.AdaptiveAvgPool2d(output_size=(1, 1)), nn.Flatten(), nn.Linear(64, 10),)
+        if with_fsdp:
+            self.trunk = FSDP(self.trunk)
+            self.head = FSDP(self.head)
+
+    def forward(self, x):
+        return self.head(self.trunk(x))
+
+    def _create_block(self, in_channels, out_channels, with_fsdp):
+        block = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3), nn.ReLU(inplace=True),)
+        if with_fsdp:
+            block = FSDP(block)
+        return block
+
+
+def _create_model(with_fsdp, with_nested_trunk):
+    if with_nested_trunk:
+        model = NestedTrunkModel(with_fsdp)
+    else:
+        model = Model()
+        if with_fsdp:
+            model.trunk = FSDP(model.trunk)
+            model.head = FSDP(model.head)
     return model


+class FreezingMethod(str, Enum):
+    GradToNone = "grad_to_none"
+    RequiresGrad = "requires_grad"
+
+
 def _distributed_worker(
-    gpu_id, world_size, with_fsdp, freezing_method, tempfile_name, unused, rank_0_output, expected_state
+    gpu_id,
+    world_size,
+    with_fsdp,
+    with_nested_trunk,
+    freezing_method,
+    tempfile_name,
+    unused,
+    rank_0_output,
+    expected_state,
 ):
     torch.cuda.set_device(gpu_id)
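
Moving FreezingMethod from bare strings to a `str`-plus-`Enum` class keeps existing string comparisons working, while a typo in a method name now raises immediately instead of silently comparing unequal. A small sketch of why the mixin behaves this way (it simply repeats the class added above):

    from enum import Enum

    class FreezingMethod(str, Enum):
        GradToNone = "grad_to_none"
        RequiresGrad = "requires_grad"

    # The str mixin keeps value-level comparisons working:
    assert FreezingMethod.RequiresGrad == "requires_grad"
    # A misspelled member now fails loudly at attribute lookup
    # instead of quietly comparing unequal:
    # FreezingMethod.RequiresGrads  # raises AttributeError
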
@@ -59,12 +95,11 @@ def _distributed_worker(
     torch.backends.cudnn.deterministic = True

     batch = torch.randn(size=(2, 3, 224, 224)).cuda()

-    model = _create_model(with_fsdp)
+    model = _create_model(with_fsdp, with_nested_trunk)
     model = model.cuda()

     # freezing the trunk using requires_grad.
-    assert freezing_method in ["requires_grad", "grad_to_none"]
-    if freezing_method == "requires_grad":
+    if freezing_method == FreezingMethod.RequiresGrad:
         for param in model.trunk.parameters():
             param.requires_grad = False
@@ -86,7 +121,7 @@ def _distributed_worker(
         print("Loss", iteration, ":", fake_loss.item())
         optimizer.zero_grad()
         fake_loss.backward()
-        if freezing_method == "grad_to_none":
+        if freezing_method == FreezingMethod.GradToNone:
             for param in model.trunk.parameters():
                 param.grad = None
         optimizer.step()
@@ -118,21 +153,30 @@ def temp_files():
 @skip_if_single_gpu
-def test_freezing_weights(temp_files):
+@pytest.mark.parametrize("nested_trunk", ["nested_trunk", "simple_trunk"])
+def test_freezing_weights(temp_files, nested_trunk):
+    with_nested_trunk = nested_trunk == "nested_trunk"
     world_size = 2

     # DDP
-    fsdp = False
-    freezing_method = "requires_grad"
-    mp.spawn(_distributed_worker, (world_size, fsdp, freezing_method) + temp_files[0:3] + (None,), nprocs=world_size)
+    with_fsdp = False
+    freezing_method = FreezingMethod.RequiresGrad
+    mp.spawn(
+        _distributed_worker,
+        (world_size, with_fsdp, with_nested_trunk, freezing_method) + temp_files[0:3] + (None,),
+        nprocs=world_size,
+    )

     # FSDP, case 1 and 2.
-    fsdp = True
+    with_fsdp = True
     expected_state = torch.load(temp_files[2])
     temp_file_idx = 3
-    for freezing_method in ["requires_grad", "grad_to_none"]:
+    for freezing_method in [FreezingMethod.RequiresGrad, FreezingMethod.GradToNone]:
         print(f"Testing FSDP with freezing method {freezing_method}")
         mp.spawn(
             _distributed_worker,
-            (world_size, fsdp, freezing_method) + temp_files[temp_file_idx : temp_file_idx + 3] + (expected_state,),
+            (world_size, with_fsdp, with_nested_trunk, freezing_method)
+            + temp_files[temp_file_idx : temp_file_idx + 3]
+            + (expected_state,),
             nprocs=world_size,
         )
         temp_file_idx += 3
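
In the test above, the DDP run with FreezingMethod.RequiresGrad produces the reference state, and both FSDP freezing methods are expected to match it; the grad_to_none variant reaches the same result because PyTorch optimizers skip any parameter whose .grad is None. A minimal single-process sketch of that variant, with hypothetical layer names and no FSDP involved:

    import torch
    import torch.nn as nn

    trunk, head = nn.Linear(4, 4), nn.Linear(4, 1)
    model = nn.Sequential(trunk, head)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    before = trunk.weight.detach().clone()
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()

    # "grad_to_none": drop the trunk's gradients before stepping;
    # the optimizer skips parameters whose .grad is None.
    for p in trunk.parameters():
        p.grad = None
    opt.step()

    print(torch.equal(before, trunk.weight))  # True: the trunk stayed frozen
    print(head.weight.grad is not None)       # True: the head was updated
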