Unverified commit 1bc96fa8 authored by Christopher Dewan, committed by GitHub

[FSDP] Upstream fairseq big changes (#956)



* made gradient predivide factor configurable

* fix lints
Co-authored-by: Your Name <you@example.com>
parent 3c24beb9
@@ -331,6 +331,7 @@ class FullyShardedDataParallel(nn.Module):
         cpu_offload: bool = False,
         offload_config: Optional[OffloadConfig] = None,
         state_dict_on_rank_0_only: bool = False,
+        gradient_predivide_factor: Optional[float] = None,
     ):
         try:
             import torch._C
@@ -399,7 +400,9 @@ class FullyShardedDataParallel(nn.Module):
         # Experimental feature for now. Use at your own risk.
         self.ssd_offload = True if offload_config and offload_config.offload_type == "ssd_offload" else False
-        self.gradient_predivide_factor: float = self._get_gradient_predivide_factor(self.world_size)
+        self.gradient_predivide_factor: float = gradient_predivide_factor or self._get_gradient_predivide_factor(
+            self.world_size
+        )
         self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
         self.numel_padded_per_param: List[int] = []
...
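For context, a minimal usage sketch of the new argument. The model and training setup below are illustrative only (not part of the patch), and it assumes a distributed process group has already been initialized via torch.distributed.init_process_group:

import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel as FSDP

# Hypothetical model; any nn.Module works.
model = nn.Linear(1024, 1024).cuda()

# With this change, the pre-divide factor can be set explicitly instead of
# being derived from world_size. Per the diff, gradients are divided by this
# factor, and the post-divide factor becomes world_size / factor, so the
# combined scaling still averages over world_size.
fsdp_model = FSDP(model, gradient_predivide_factor=2.0)

# Passing None (the default) keeps the previous behavior: the factor is
# computed by _get_gradient_predivide_factor(world_size).
default_model = FSDP(nn.Linear(1024, 1024).cuda())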