Unverified Commit b4fbc2b3 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[PyTorch] Use same API in optimizer `zero_grad` as PyTorch optimizers (#1466)



Use same API in optimizer zero_grad as PyT optimizers
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
parent 257345a5
......@@ -3,8 +3,12 @@
# See LICENSE for license information.
"""Fused Adam optimizer."""
from __future__ import annotations
from collections.abc import Iterable
from copy import deepcopy
from itertools import chain
from typing import Optional
import warnings
import torch
import transformer_engine_torch as tex
......@@ -52,8 +56,6 @@ class FusedAdam(torch.optim.Optimizer):
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
bias_correction (bool, optional): apply correction factor to
moment estimates. (default: True)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
......@@ -62,10 +64,10 @@ class FusedAdam(torch.optim.Optimizer):
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED in FusedAdam!
bias_correction (bool, optional): apply correction factor to
moment estimates. (default: True)
adam_w_mode (boolean, optional): Apply L2 regularization or weight decay
True for decoupled weight decay(also known as AdamW) (default: True)
set_grad_none (bool, optional): whether set grad to None when zero_grad()
method is called. (default: True)
capturable (bool, optional): whether to use the version of the optimizer
that can be used with CUDA Graphs. (default: False)
master_weights (bool, optional): whether to maintain FP32 master weights
......@@ -106,15 +108,15 @@ class FusedAdam(torch.optim.Optimizer):
def __init__(
self,
params,
lr=1e-3,
params: Iterable[torch.nn.Parameter | dict],
lr: float = 1e-3,
betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 0.0,
amsgrad: bool = False,
*,
bias_correction=True,
betas=(0.9, 0.999),
eps=1e-8,
adam_w_mode=True,
weight_decay=0.0,
amsgrad=False,
set_grad_none=True,
capturable=False,
master_weights=False,
master_weight_dtype=torch.float32,
......@@ -122,6 +124,7 @@ class FusedAdam(torch.optim.Optimizer):
exp_avg_sq_dtype=torch.float32,
use_decoupled_grad=False,
store_param_remainders=False,
set_grad_none: Optional[bool] = None, # deprecated
):
if amsgrad:
......@@ -160,7 +163,6 @@ class FusedAdam(torch.optim.Optimizer):
}
super().__init__(params, defaults)
self.adam_w_mode = 1 if adam_w_mode else 0
self.set_grad_none = set_grad_none
self.capturable = capturable
self.master_weights = master_weights
......@@ -204,19 +206,46 @@ class FusedAdam(torch.optim.Optimizer):
store_param_remainders and master_weights and master_weight_dtype == torch.float32
)
def zero_grad(self):
# pylint: disable=missing-function-docstring
if not self.use_decoupled_grad and not self.set_grad_none:
super().zero_grad()
# Deprecated options
self.set_grad_none = set_grad_none
if self.set_grad_none is not None:
warnings.warn(
"set_grad_none kwarg in FusedAdam constructor is deprecated. "
"Use set_to_none kwarg in zero_grad instead.",
DeprecationWarning,
)
def zero_grad(self, set_to_none: Optional[bool] = None) -> None:
"""Reset parameter gradients.
Arguments:
set_to_none (bool, optional): whether to set grads to `None`
instead of zeroing out buffers. (default: True)
"""
# Handle deprecated set_grad_none option
if self.set_grad_none is not None:
if set_to_none is not None and set_to_none != self.set_grad_none:
raise ValueError(
f"Called zero_grad with set_to_none={set_to_none}, "
f"but FusedAdam was initialized with set_grad_none={self.set_grad_none}"
)
set_to_none = self.set_grad_none
if set_to_none is None:
set_to_none = True
if not self.use_decoupled_grad and not set_to_none:
super().zero_grad(set_to_none=set_to_none)
return
for group in self.param_groups:
for p in group["params"]:
if self.use_decoupled_grad and self.set_grad_none:
if self.use_decoupled_grad and set_to_none:
p.decoupled_grad = None
elif self.use_decoupled_grad and not self.set_grad_none:
elif self.use_decoupled_grad and not set_to_none:
p.decoupled_grad.zero_()
elif not self.use_decoupled_grad and self.set_grad_none:
elif not self.use_decoupled_grad and set_to_none:
p.grad = None
def _apply_scale(self, state_name, unscaled_state, scaled_state, scale):
......
......@@ -3,6 +3,11 @@
# See LICENSE for license information.
"""Fused SGD optimizer."""
from __future__ import annotations
from collections.abc import Iterable
from typing import Any, Optional
import warnings
import torch
from torch.optim.optimizer import Optimizer, required
......@@ -37,8 +42,8 @@ class FusedSGD(Optimizer):
parameter groups
lr (float): learning rate
momentum (float, optional): momentum factor (default: 0)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
dampening (float, optional): dampening for momentum (default: 0)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
nesterov (bool, optional): enables Nesterov momentum (default: False)
Example:
......@@ -74,15 +79,16 @@ class FusedSGD(Optimizer):
def __init__(
self,
params,
lr=required,
momentum=0,
dampening=0,
weight_decay=0,
nesterov=False,
params: Iterable[torch.nn.Parameter | dict],
lr: float | Any = required,
momentum: float = 0.0,
dampening: float = 0.0,
weight_decay: float = 0.0,
nesterov: bool = False,
*,
wd_after_momentum=False,
materialize_master_grads=True,
set_grad_none=False,
set_grad_none: Optional[bool] = None, # deprecated
):
if lr is not required and lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr}")
......@@ -98,7 +104,7 @@ class FusedSGD(Optimizer):
"weight_decay": weight_decay,
"nesterov": nesterov,
}
if nesterov and (momentum <= 0 or dampening != 0):
if nesterov and (momentum <= 0.0 or dampening != 0.0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super().__init__(params, defaults)
......@@ -106,7 +112,6 @@ class FusedSGD(Optimizer):
self.materialize_master_grads = materialize_master_grads
self.most_recent_scale = 1.0
self.scale_set_by_backward = False
self.set_grad_none = set_grad_none
# Skip buffer
self._dummy_overflow_buf = torch.tensor(
......@@ -114,14 +119,42 @@ class FusedSGD(Optimizer):
)
self.multi_tensor_sgd = tex.multi_tensor_sgd
# Deprecated options
self.set_grad_none = set_grad_none
if self.set_grad_none is not None:
warnings.warn(
"set_grad_none kwarg in FusedAdam constructor is deprecated. "
"Use set_to_none kwarg in zero_grad instead.",
DeprecationWarning,
)
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("nesterov", False)
def zero_grad(self):
# pylint: disable=missing-function-docstring
if self.set_grad_none:
def zero_grad(self, set_to_none: Optional[bool] = None) -> None:
"""Reset parameter gradients.
Arguments:
set_to_none (bool, optional): whether to set grads to `None`
instead of zeroing out buffers. (default: True)
"""
# Handle deprecated set_grad_none option
if self.set_grad_none is not None:
if set_to_none is not None and set_to_none != self.set_grad_none:
raise ValueError(
f"Called zero_grad with set_to_none={set_to_none}, "
f"but FusedAdam was initialized with set_grad_none={self.set_grad_none}"
)
set_to_none = self.set_grad_none
if set_to_none is None:
set_to_none = True
# Reset grads
if set_to_none:
for group in self.param_groups:
for p in group["params"]:
p.grad = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment