Unverified commit b4fbc2b3 authored by Tim Moon, committed by GitHub

[PyTorch] Use same API in optimizer `zero_grad` as PyTorch optimizers (#1466)



Use same API in optimizer zero_grad as PyTorch optimizers
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 257345a5
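
The change aligns the fused optimizers' `zero_grad` signature with `torch.optim.Optimizer.zero_grad(set_to_none=True)`: the zeroing behavior is now chosen per call instead of being fixed at construction through `set_grad_none`. A minimal sketch of the resulting call pattern, assuming a CUDA build of Transformer Engine that exposes `FusedAdam` under `transformer_engine.pytorch.optimizers`:

```python
import torch
from transformer_engine.pytorch.optimizers import FusedAdam  # import path assumed

model = torch.nn.Linear(16, 16).cuda()
optimizer = FusedAdam(model.parameters(), lr=1e-3)

model(torch.randn(4, 16, device="cuda")).sum().backward()
optimizer.step()

# Same call signature as torch.optim.Optimizer.zero_grad
optimizer.zero_grad()                   # default: release grads by setting them to None
optimizer.zero_grad(set_to_none=False)  # zero the gradient buffers in place instead
```

The deprecated `set_grad_none` constructor kwarg is still accepted, but it now only seeds the default and emits a `DeprecationWarning`, as the diffs below show.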
@@ -3,8 +3,12 @@
 # See LICENSE for license information.
 
 """Fused Adam optimizer."""
+from __future__ import annotations
+from collections.abc import Iterable
 from copy import deepcopy
 from itertools import chain
+from typing import Optional
+import warnings
 
 import torch
 import transformer_engine_torch as tex
@@ -52,8 +56,6 @@ class FusedAdam(torch.optim.Optimizer):
         params (iterable): iterable of parameters to optimize or dicts defining
             parameter groups.
         lr (float, optional): learning rate. (default: 1e-3)
-        bias_correction (bool, optional): apply correction factor to
-            moment estimates. (default: True)
         betas (Tuple[float, float], optional): coefficients used for computing
             running averages of gradient and its square. (default: (0.9, 0.999))
         eps (float, optional): term added to the denominator to improve
@@ -62,10 +64,10 @@ class FusedAdam(torch.optim.Optimizer):
         amsgrad (boolean, optional): whether to use the AMSGrad variant of this
             algorithm from the paper `On the Convergence of Adam and Beyond`_
             (default: False) NOT SUPPORTED in FusedAdam!
+        bias_correction (bool, optional): apply correction factor to
+            moment estimates. (default: True)
         adam_w_mode (boolean, optional): Apply L2 regularization or weight decay
             True for decoupled weight decay(also known as AdamW) (default: True)
-        set_grad_none (bool, optional): whether set grad to None when zero_grad()
-            method is called. (default: True)
         capturable (bool, optional): whether to use the version of the optimizer
             that can be used with CUDA Graphs. (default: False)
         master_weights (bool, optional): whether to maintain FP32 master weights
@@ -106,15 +108,15 @@ class FusedAdam(torch.optim.Optimizer):
     def __init__(
         self,
-        params,
-        lr=1e-3,
+        params: Iterable[torch.nn.Parameter | dict],
+        lr: float = 1e-3,
+        betas: tuple[float, float] = (0.9, 0.999),
+        eps: float = 1e-8,
+        weight_decay: float = 0.0,
+        amsgrad: bool = False,
+        *,
         bias_correction=True,
-        betas=(0.9, 0.999),
-        eps=1e-8,
         adam_w_mode=True,
-        weight_decay=0.0,
-        amsgrad=False,
-        set_grad_none=True,
         capturable=False,
         master_weights=False,
         master_weight_dtype=torch.float32,
@@ -122,6 +124,7 @@ class FusedAdam(torch.optim.Optimizer):
         exp_avg_sq_dtype=torch.float32,
         use_decoupled_grad=False,
         store_param_remainders=False,
+        set_grad_none: Optional[bool] = None,  # deprecated
     ):
         if amsgrad:
@@ -160,7 +163,6 @@ class FusedAdam(torch.optim.Optimizer):
         }
         super().__init__(params, defaults)
         self.adam_w_mode = 1 if adam_w_mode else 0
-        self.set_grad_none = set_grad_none
         self.capturable = capturable
         self.master_weights = master_weights
@@ -204,19 +206,46 @@ class FusedAdam(torch.optim.Optimizer):
             store_param_remainders and master_weights and master_weight_dtype == torch.float32
         )
 
-    def zero_grad(self):
-        # pylint: disable=missing-function-docstring
-        if not self.use_decoupled_grad and not self.set_grad_none:
-            super().zero_grad()
+        # Deprecated options
+        self.set_grad_none = set_grad_none
+        if self.set_grad_none is not None:
+            warnings.warn(
+                "set_grad_none kwarg in FusedAdam constructor is deprecated. "
+                "Use set_to_none kwarg in zero_grad instead.",
+                DeprecationWarning,
+            )
+
+    def zero_grad(self, set_to_none: Optional[bool] = None) -> None:
+        """Reset parameter gradients.
+
+        Arguments:
+            set_to_none (bool, optional): whether to set grads to `None`
+                instead of zeroing out buffers. (default: True)
+
+        """
+
+        # Handle deprecated set_grad_none option
+        if self.set_grad_none is not None:
+            if set_to_none is not None and set_to_none != self.set_grad_none:
+                raise ValueError(
+                    f"Called zero_grad with set_to_none={set_to_none}, "
+                    f"but FusedAdam was initialized with set_grad_none={self.set_grad_none}"
+                )
+            set_to_none = self.set_grad_none
+        if set_to_none is None:
+            set_to_none = True
+
+        if not self.use_decoupled_grad and not set_to_none:
+            super().zero_grad(set_to_none=set_to_none)
             return
         for group in self.param_groups:
             for p in group["params"]:
-                if self.use_decoupled_grad and self.set_grad_none:
+                if self.use_decoupled_grad and set_to_none:
                     p.decoupled_grad = None
-                elif self.use_decoupled_grad and not self.set_grad_none:
+                elif self.use_decoupled_grad and not set_to_none:
                     p.decoupled_grad.zero_()
-                elif not self.use_decoupled_grad and self.set_grad_none:
+                elif not self.use_decoupled_grad and set_to_none:
                     p.grad = None
 
     def _apply_scale(self, state_name, unscaled_state, scaled_state, scale):
...
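
The bulk of the new `zero_grad` is the reconciliation between the per-call `set_to_none` argument and the deprecated `set_grad_none` constructor flag. A standalone sketch of that precedence logic follows; the helper name `resolve_set_to_none` is illustrative and not part of the library:

```python
from typing import Optional


def resolve_set_to_none(set_grad_none: Optional[bool], set_to_none: Optional[bool]) -> bool:
    """Mirror the precedence rules in the diff above (illustrative helper)."""
    if set_grad_none is not None:  # deprecated constructor flag was supplied
        if set_to_none is not None and set_to_none != set_grad_none:
            # Conflicting requests raise instead of silently preferring one side
            raise ValueError("zero_grad(set_to_none=...) conflicts with set_grad_none")
        set_to_none = set_grad_none
    if set_to_none is None:
        set_to_none = True  # matches the torch.optim default
    return set_to_none


assert resolve_set_to_none(None, None) is True    # new default behavior
assert resolve_set_to_none(False, None) is False  # deprecated flag still honored
assert resolve_set_to_none(None, False) is False  # per-call override
```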
@@ -3,6 +3,11 @@
 # See LICENSE for license information.
 
 """Fused SGD optimizer."""
+from __future__ import annotations
+from collections.abc import Iterable
+from typing import Any, Optional
+import warnings
+
 import torch
 from torch.optim.optimizer import Optimizer, required
@@ -37,8 +42,8 @@ class FusedSGD(Optimizer):
             parameter groups
         lr (float): learning rate
         momentum (float, optional): momentum factor (default: 0)
-        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
         dampening (float, optional): dampening for momentum (default: 0)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
         nesterov (bool, optional): enables Nesterov momentum (default: False)
 
     Example:
@@ -74,15 +79,16 @@ class FusedSGD(Optimizer):
     def __init__(
         self,
-        params,
-        lr=required,
-        momentum=0,
-        dampening=0,
-        weight_decay=0,
-        nesterov=False,
+        params: Iterable[torch.nn.Parameter | dict],
+        lr: float | Any = required,
+        momentum: float = 0.0,
+        dampening: float = 0.0,
+        weight_decay: float = 0.0,
+        nesterov: bool = False,
+        *,
         wd_after_momentum=False,
         materialize_master_grads=True,
-        set_grad_none=False,
+        set_grad_none: Optional[bool] = None,  # deprecated
     ):
         if lr is not required and lr < 0.0:
             raise ValueError(f"Invalid learning rate: {lr}")
@@ -98,7 +104,7 @@ class FusedSGD(Optimizer):
             "weight_decay": weight_decay,
             "nesterov": nesterov,
         }
-        if nesterov and (momentum <= 0 or dampening != 0):
+        if nesterov and (momentum <= 0.0 or dampening != 0.0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super().__init__(params, defaults)
@@ -106,7 +112,6 @@ class FusedSGD(Optimizer):
         self.materialize_master_grads = materialize_master_grads
         self.most_recent_scale = 1.0
         self.scale_set_by_backward = False
-        self.set_grad_none = set_grad_none
 
         # Skip buffer
         self._dummy_overflow_buf = torch.tensor(
@@ -114,14 +119,42 @@ class FusedSGD(Optimizer):
         )
         self.multi_tensor_sgd = tex.multi_tensor_sgd
 
+        # Deprecated options
+        self.set_grad_none = set_grad_none
+        if self.set_grad_none is not None:
+            warnings.warn(
+                "set_grad_none kwarg in FusedAdam constructor is deprecated. "
+                "Use set_to_none kwarg in zero_grad instead.",
+                DeprecationWarning,
+            )
+
     def __setstate__(self, state):
         super().__setstate__(state)
         for group in self.param_groups:
             group.setdefault("nesterov", False)
 
-    def zero_grad(self):
-        # pylint: disable=missing-function-docstring
-        if self.set_grad_none:
+    def zero_grad(self, set_to_none: Optional[bool] = None) -> None:
+        """Reset parameter gradients.
+
+        Arguments:
+            set_to_none (bool, optional): whether to set grads to `None`
+                instead of zeroing out buffers. (default: True)
+
+        """
+
+        # Handle deprecated set_grad_none option
+        if self.set_grad_none is not None:
+            if set_to_none is not None and set_to_none != self.set_grad_none:
+                raise ValueError(
+                    f"Called zero_grad with set_to_none={set_to_none}, "
+                    f"but FusedAdam was initialized with set_grad_none={self.set_grad_none}"
+                )
+            set_to_none = self.set_grad_none
+        if set_to_none is None:
+            set_to_none = True
+
+        # Reset grads
+        if set_to_none:
             for group in self.param_groups:
                 for p in group["params"]:
                     p.grad = None
...
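
With both fused optimizers accepting `set_to_none`, a training loop written against the standard `torch.optim.Optimizer` interface no longer needs optimizer-specific handling. A hedged usage sketch, assuming a CUDA build of Transformer Engine (the fused kernels require CUDA tensors) and the same import path as above:

```python
import torch
from transformer_engine.pytorch.optimizers import FusedSGD  # import path assumed

def train_step(model, optimizer, batch):
    # Relies only on the generic torch.optim.Optimizer interface
    optimizer.zero_grad(set_to_none=True)
    loss = model(batch).sum()
    loss.backward()
    optimizer.step()

model = torch.nn.Linear(32, 32).cuda()
batch = torch.randn(8, 32, device="cuda")
for optimizer in (
    torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9),
    FusedSGD(model.parameters(), lr=0.1, momentum=0.9),
):
    train_step(model, optimizer, batch)
```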