"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "f50a92869c37bbc54fe08af819304e6ad1011e84"
Unverified commit 4e438ba1, authored by Benjamin Lefaudeux and committed by GitHub

[fix] SDP: expose module property fix + unit test (#647)

* fix + unit test
* changelog update
parent b66168da
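Context for the fix: torch's DistributedDataParallel exposes the wrapped network through a `module` attribute, and downstream code commonly relies on that to unwrap the model (checkpointing, logging, etc.). ShardedDataParallel only kept the network in a private `_module` field, breaking that parity; this commit re-exposes it as `module`. A minimal sketch of the calling pattern this restores follows; the `unwrap` helper and the checkpoint path in the comment are illustrative, not fairscale code.

# Illustrative helper, not part of fairscale: downstream code often unwraps
# a parallel wrapper through `.module`, exactly as it would with torch DDP.
import torch


def unwrap(model: torch.nn.Module) -> torch.nn.Module:
    # Works for torch.nn.parallel.DistributedDataParallel and, with this fix,
    # for fairscale's ShardedDataParallel as well.
    return model.module if hasattr(model, "module") else model


# e.g. torch.save(unwrap(wrapped_model).state_dict(), "checkpoint.pt")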
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## NEXT - TBD
 ### Fixed
 - setup.py: hide CUDA extensions behind BUILD_CUDA_EXTENSIONS envvar
+- SDP: re-expose the module property ([#647](https://github.com/facebookresearch/fairscale/pull/647))
 ### Added
 - FSDP: better memory usage for reduce bucket ([#633](https://github.com/facebookresearch/fairscale/pull/633))
...
@@ -103,7 +103,9 @@ class ShardedDataParallel(nn.Module):
     ):
         super().__init__()
-        self._module = module
+        # This field needs to be exposed to insure interface parity with DDP
+        self.module = module
         self._sharded_optimizers = [sharded_optimizer] if not isinstance(sharded_optimizer, list) else sharded_optimizer
         self._enable_broadcast_buffers = broadcast_buffers
         self._auto_refresh_trainable = auto_refresh_trainable
@@ -133,10 +135,10 @@ class ShardedDataParallel(nn.Module):
         # Expose some of the PytorchDDP attributes, some frameworks rely on them.
         # See https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel
         # device_id related logic is not present, this is not handled
-        devices = {p.device for p in self._module.parameters()}
+        devices = {p.device for p in self.module.parameters()}
         self.is_multi_device_module = len(devices) > 1
-        distinct_device_types = {p.device.type for p in self._module.parameters()}
+        distinct_device_types = {p.device.type for p in self.module.parameters()}
         assert len(distinct_device_types) == 1, (
             "ShardedDataParallel's input module must be on "
             "the same type of devices, but input module parameters are located on {} different device types."
@@ -161,7 +163,7 @@ class ShardedDataParallel(nn.Module):
         self._reference_trainable_mask = list(map(_trainable, self._all_params))
         # - setup buckets and tensor views
-        model_size = sum([p.numel() for p in self._module.parameters()])
+        model_size = sum([p.numel() for p in self.module.parameters()])
         self._buffer_max_size = min(reduce_buffer_size, model_size)
         if dist.get_world_size(self._process_group) == 1:
@@ -185,7 +187,7 @@ class ShardedDataParallel(nn.Module):
         self._manual_reduce: List[Callable] = []
         # passing a handle to torch.nn.SyncBatchNorm layer
-        self._passing_sync_batchnorm_handle(self._module)
+        self._passing_sync_batchnorm_handle(self.module)
         # Make sure that all ranks start with the same model
         if sync_models_at_startup:
@@ -219,7 +221,7 @@ class ShardedDataParallel(nn.Module):
         self._clear_counters()
         # Normal FW on the base model
-        return self._module(*inputs, **kwargs)
+        return self.module(*inputs, **kwargs)
     def to(  # type: ignore
         self,
@@ -267,7 +269,7 @@ class ShardedDataParallel(nn.Module):
             for bucket in self._buckets[_device].values():
                 bucket.to(device=_device, dtype=dtype, non_blocking=non_blocking)
-        self._module.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.module.to(device=device, dtype=dtype, non_blocking=non_blocking)
     def refresh_trainable(self) -> None:
         """ If the module trainability has changed, update all the assumptions """
@@ -328,7 +330,7 @@ class ShardedDataParallel(nn.Module):
         with profiler.record_function("fairscale::sdp::sync_buffers"):
             work_handles = []
-            for buffer in self._module.buffers(recurse=True):
+            for buffer in self.module.buffers(recurse=True):
                 work_handles.append(
                     dist.broadcast(buffer.data, self._reference_global_rank, self._process_group, async_op=True)
                 )
@@ -362,7 +364,7 @@ class ShardedDataParallel(nn.Module):
         try:
             return super().__getattr__(name)  # defer to nn.Module's logic
         except AttributeError:
-            return getattr(self._module, name)
+            return getattr(self.module, name)
     @contextlib.contextmanager
     def no_sync(self) -> Generator:
@@ -528,7 +530,7 @@ class ShardedDataParallel(nn.Module):
         work_handles = []
-        for t in self._module.state_dict().values():
+        for t in self.module.state_dict().values():
             work_handles.append(
                 dist.broadcast(t, src=self._reference_global_rank, group=self._process_group, async_op=True)
             )
...
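A note on the `__getattr__` hunk above: because `nn.Module.__setattr__` registers `self.module = module` as a submodule, `wrapper.module` is resolved by `nn.Module.__getattr__` before the fallback to the wrapped model ever runs, while any other unknown attribute still gets forwarded. A stripped-down sketch of that resolution order, using illustrative `Wrapper`/`Inner` classes rather than fairscale code:

# Illustrative-only classes (not fairscale code) showing how the explicit
# `self.module = module` assignment and the __getattr__ fallback coexist.
import torch


class Inner(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)
        self.some_flag = True  # plain attribute on the wrapped model


class Wrapper(torch.nn.Module):
    def __init__(self, module: torch.nn.Module) -> None:
        super().__init__()
        self.module = module  # registered in _modules, reachable without the fallback

    def __getattr__(self, name: str):
        try:
            return super().__getattr__(name)  # parameters, buffers, submodules
        except AttributeError:
            return getattr(self.module, name)  # forward everything else to the wrapped model


wrapped = Wrapper(Inner())
assert wrapped.module is not None  # resolved as a submodule, DDP-style
assert wrapped.some_flag is True   # forwarded to Inner via __getattr__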
@@ -236,6 +236,7 @@ def test_ddp_attributes():
     assert hasattr(ddp_model, "is_multi_device_module")
     assert hasattr(ddp_model, "device_type")
+    assert hasattr(ddp_model, "module")
     dist.destroy_process_group()
...
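The new assertion can also be reproduced outside the test suite with a single-process group. The sketch below is hypothetical, not the actual test: it assumes fairscale is installed, uses a gloo backend, and assumes TCP port 29501 is free.

# Hypothetical standalone reproduction of the new assertion, not the test code.
import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

dist.init_process_group(backend="gloo", init_method="tcp://localhost:29501", rank=0, world_size=1)

model = torch.nn.Linear(4, 4)
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)
sdp_model = ShardedDataParallel(model, optimizer)

assert hasattr(sdp_model, "module")  # the attribute re-exposed by this commit
assert sdp_model.module is model     # same parity contract as torch DDP

dist.destroy_process_group()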