"git@developer.sourcefind.cn:OpenDAS/torch-spline-conv.git" did not exist on "1b9ec8ee61def747dc169b52fa7966237e1f110d"
Commit 04001e76 authored by Shruti Bhosale, committed by GitHub

[FSDP] Add gradient predivide factor to avoid overflow/underflow with large world size (#565)

parent 5a3df0da
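
Background for the change (an illustration, not part of the commit): FSDP previously divided each FP16 gradient by the full world size before the reduction. With a very large world size that pre-division can flush small gradient values below the FP16 representable range, while skipping the pre-division risks overflowing FP16 during the summation. Splitting the division into a pre-divide of roughly sqrt(world_size) and a matching post-divide keeps intermediate values in range while still averaging by world_size overall. A minimal single-process sketch (plain torch, no process group; the multiplication by world_size stands in for the all-reduce sum):

import torch

world_size = 4096
predivide = 64                      # ~sqrt(world_size); see get_gradient_predivide_factor below
postdivide = world_size / predivide

grad = torch.full((1,), 40.0, dtype=torch.float16)

# Summing raw FP16 grads from 4096 ranks overflows FP16 (max finite value ~65504).
print(grad * world_size)            # tensor([inf], dtype=torch.float16)

# Pre-dividing by the full world size can instead underflow small grads to zero.
small = torch.full((1,), 1e-5, dtype=torch.float16)
print(small / world_size)           # tensor([0.], dtype=torch.float16)

# Splitting the division keeps the reduction in range and still restores the average.
reduced = (grad / predivide) * world_size   # 0.625 * 4096 = 2560, stays finite
print(reduced / postdivide)                 # tensor([40.], dtype=torch.float16)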
@@ -187,6 +187,8 @@ class FullyShardedDataParallel(nn.Module):
         self.buffer_dtype = buffer_dtype or self.compute_dtype
         self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
         self.bucket_cap_mb = bucket_cap_mb
+        self.gradient_predivide_factor: int = self.get_gradient_predivide_factor(self.world_size)
+        self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
         self.numel_padded_per_param: List[int] = []
         self.compute_device = compute_device
@@ -252,6 +254,12 @@ class FullyShardedDataParallel(nn.Module):
         # user explicitly requests the local state dict via local_state_dict().
         self._return_full_state_dict = True

+    def get_gradient_predivide_factor(self, world_size: int) -> int:
+        factor = 1
+        while world_size % factor == 0 and world_size / factor > factor:
+            factor = factor * 2
+        return factor
+
     @property
     def module(self) -> nn.Module:
         return self._fsdp_wrapped_module  # note: may be a FlattenParamsWrapper instance
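
The helper doubles the factor until it reaches roughly sqrt(world_size), so the pre- and post-divide factors are balanced. A standalone sketch of the same loop with a few example world sizes (copied here as a plain function for illustration):

def get_gradient_predivide_factor(world_size: int) -> int:
    # Roughly the smallest power of two at or above sqrt(world_size);
    # the divisibility check can stop the doubling early for odd world sizes.
    factor = 1
    while world_size % factor == 0 and world_size / factor > factor:
        factor = factor * 2
    return factor

for ws in (1, 2, 8, 16, 64, 128):
    pre = get_gradient_predivide_factor(ws)
    print(ws, pre, ws / pre)   # world_size, predivide, postdivide
# 1 1 1.0
# 2 2 1.0
# 8 4 2.0
# 16 4 4.0
# 64 8 8.0
# 128 16 8.0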
@@ -1069,9 +1077,9 @@ class FullyShardedDataParallel(nn.Module):
            # Cast grad to FP32.
            param.grad.data = param.grad.data.to(param.dtype)

-           if self.world_size > 1:
+           if self.gradient_predivide_factor > 1:
                # Average grad by world_size for consistency with PyTorch DDP.
-               param.grad.data.div_(self.world_size)
+               param.grad.data.div_(self.gradient_predivide_factor)

            callback_fn = functools.partial(self._post_reduction_hook, param)
            if param._is_sharded:
@@ -1099,6 +1107,9 @@ class FullyShardedDataParallel(nn.Module):
        assert param.grad is not None
        self.assert_state(TrainingState.BACKWARD_POST)
        param.grad.data = reduced_grad
+       if self.gradient_postdivide_factor > 1:
+           # Average grad by world_size for consistency with PyTorch DDP.
+           param.grad.data.div_(self.gradient_postdivide_factor)
        # Cast grad to param's dtype (typically FP32). Note: we do this
        # before the move_grads_to_cpu step so that this entire hook remains
        # non-blocking. The downside is a bit more D2H transfer in that case.
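
Taken together, the two hooks scale each gradient by 1/gradient_predivide_factor before the reduction and by 1/gradient_postdivide_factor afterwards, so the combined scale is still 1/world_size and the result matches DDP-style averaging. A minimal single-process check of that arithmetic (no process group; the Python sum stands in for the reduce op):

import torch

world_size = 64
predivide = 8                         # get_gradient_predivide_factor(64)
postdivide = world_size / predivide   # 8.0
assert predivide * postdivide == world_size   # combined scaling is 1 / world_size

# One identical FP16 gradient per simulated rank.
local_grad = torch.full((3,), 0.5, dtype=torch.float16)

reduced = sum(local_grad / predivide for _ in range(world_size))  # pre-divide, then "reduce"
averaged = reduced / postdivide                                   # post-divide after the reduction

print(averaged)   # tensor([0.5000, 0.5000, 0.5000], dtype=torch.float16), the DDP average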