Unverified Commit ef7146d5 authored by Benjamin Lefaudeux, committed by GitHub

[fix][minor] ShardedDDP train/eval modes (#393)

* [fix] ShardedDDP train/eval modes
* Update CHANGELOG.md
parent 47042917
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- ShardedDDP and OSS handle model trainability changes during training ([#369](https://github.com/facebookresearch/fairscale/issues/369))
- ShardedDDP state dict load/save bug (#386)
- ShardedDDP handles train/eval modes (#393)
### Added
- ShardedDDP manual reduce option for checkpointing (#389)
@@ -352,7 +352,7 @@ class ShardedDataParallel(nn.Module):
assert self._bucket_list is not None
for bucket in self._bucket_list:
    assert not self.training or self.should_accumulate_grads or bucket.sent, (
        "A bucket failed to be sent, probably unused parameters."
        + "Either remove the unused parameter or de-activate ShardedDDP buckets -set reduce_buffer_size to 0-"
    )
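With the extra `not self.training` clause above, the bucket-sent check only applies while the module is in training mode, so a forward pass after `ddp_model.eval()` no longer trips the assert. A minimal single-process sketch of the intended usage (the single-rank gloo group, model and optimizer settings here are illustrative, not part of this commit):

import tempfile

import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

# Single-rank gloo group so the sketch can run standalone on CPU (illustrative only).
dist.init_process_group(
    backend="gloo",
    init_method="file://" + tempfile.mkstemp()[1],
    rank=0,
    world_size=1,
)

model = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 3))
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer)

inputs = torch.rand(10, 2)

ddp_model.train()
ddp_model(inputs).norm().backward()  # training step: grad buckets are expected to be reduced

ddp_model.eval()  # after this fix, eval-mode forwards skip the bucket-sent check
with torch.no_grad():
    _ = ddp_model(inputs)

dist.destroy_process_group()

The multi-process equivalent of this scenario is exactly what the new test below spawns across several ranks.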
@@ -440,6 +440,32 @@ def test_device_change():
    mp.spawn(run_test_device_change, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
def run_test_training_change(rank, world_size, backend, device, temp_file_name):
    url = "file://" + temp_file_name
    group = dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)

    model = Sequential(Linear(2, 3), Linear(3, 3)).to(device)
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer, process_group=group)

    inputs = torch.rand((10, 2), device=device)
    outputs = ddp_model(inputs)  # assert if the module has not been changed properly
    _ = outputs.norm().backward()

    ddp_model.eval()
    ddp_model(inputs)  # This will assert if eval() is not properly taken into account
    ddp_model(inputs)

    dist.destroy_process_group()


def test_training_change():
    world_size = 8
    backend = "gloo"
    temp_file_name = tempfile.mkstemp()[1]
    device = "cpu"
    mp.spawn(run_test_training_change, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
def run_test_ddp_sync_batch_norm(rank, world_size, backend, device, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)