Unverified Commit 945b9666 authored by Nicholas Cilfone, committed by GitHub

[refactor] ShardedGradScaler init and super call (#691)

Make ShardedGradScaler.__init__ mirror GradScaler's signature so the super() call can forward its parameters. Without this, a ShardedGradScaler object cannot be configured the way the native PyTorch GradScaler can.
Formatted with the black linter.
Added a stub for GradScaler.__init__, which resolves the mypy issues and lets the ignore comment be removed.
parent 8a05ff76
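
For context, a minimal usage sketch of what this change enables (not part of the commit). The keyword arguments mirror torch.cuda.amp.GradScaler's; the chosen values are arbitrary examples, the import path assumes the usual fairscale location of ShardedGradScaler, and a torch.distributed process group is assumed to have been initialized elsewhere.

import torch.distributed as dist
from fairscale.optim.grad_scaler import ShardedGradScaler

# With this change, ShardedGradScaler accepts the same knobs as the native
# torch.cuda.amp.GradScaler, plus the process group it already took before.
scaler = ShardedGradScaler(
    init_scale=2.0 ** 14,     # example: start below the 2.0 ** 16 default
    growth_factor=2.0,
    backoff_factor=0.5,
    growth_interval=1000,     # example: re-grow the scale more often
    enabled=True,
    process_group=dist.group.WORLD,
)

# The scaler is then driven exactly like GradScaler in an AMP training loop:
#   scaler.scale(loss).backward()
#   scaler.step(optimizer)
#   scaler.update()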
@@ -30,8 +30,22 @@ class ShardedGradScaler(TorchGradScaler):
     documentation https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler
     """

-    def __init__(self, process_group: Any = dist.group.WORLD) -> None:
-        super().__init__()
+    def __init__(
+        self,
+        init_scale: float = 2.0 ** 16,
+        growth_factor: float = 2.0,
+        backoff_factor: float = 0.5,
+        growth_interval: int = 2000,
+        enabled: bool = True,
+        process_group: Any = dist.group.WORLD,
+    ) -> None:
+        super().__init__(
+            init_scale=init_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+            enabled=enabled,
+        )
         self.display_warning = True
         self.group = process_group
@@ -9,7 +9,8 @@ class GradScaler(object):
     _grows_tracker: Optional[Tensor]
     _per_optimizer_states: Dict[int, Dict[str, Any]]

+    def __init__(self, init_scale: float, growth_factor: float, backoff_factor: float, growth_interval: int, enabled: bool): ...
     def _unscale_grads_(self, optimizer: Optimizer, inv_scale: Tensor, found_inf: Tensor, allow_fp16: bool) -> Dict[device, Tensor]: ...
     def step(self, optimizer: Optimizer, *args: Any, **kwargs: Any): ...
     def update(self, new_scale: Optional[float]=None): ...
     def unscale_(self, optimizer: Optimizer) -> None: ...
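
For illustration only, a self-contained sketch of the pattern the two hunks rely on: once the parent __init__ signature is known to the type checker (in the real project, via the .pyi stub above), a subclass that mirrors those parameters can forward them through super().__init__ without an ignore comment. The Base/Sharded names are hypothetical stand-ins for GradScaler/ShardedGradScaler, not fairscale code.

# In the real project, Base's signature would come from a .pyi stub; here it is
# declared inline so the sketch runs on its own.
class Base:
    def __init__(self, init_scale: float = 2.0 ** 16, enabled: bool = True) -> None:
        self.init_scale = init_scale
        self.enabled = enabled


class Sharded(Base):
    # Mirroring the parent's parameters lets callers configure Sharded the same
    # way they configure Base, and the forwarded keywords can be type-checked
    # because the parent __init__ signature is declared.
    def __init__(self, init_scale: float = 2.0 ** 16, enabled: bool = True) -> None:
        super().__init__(init_scale=init_scale, enabled=enabled)


scaler = Sharded(init_scale=2.0 ** 14, enabled=True)
assert scaler.init_scale == 2.0 ** 14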