Commit c490bd36 authored by blisc, committed by mcarilli

Enable LARC for use with amp (#306)



* update larc
Signed-off-by: Jason <jasoli@nvidia.com>

* scale_loss fix
Signed-off-by: Jason <jasoli@nvidia.com>

* typo
Signed-off-by: Jason <jasoli@nvidia.com>

* revert LARC
parent a5289067
@@ -12,6 +12,7 @@ from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
 from ..optimizers import FP16_Optimizer as FP16_Optimizer_for_fused
 from ..optimizers import FusedAdam
 from ..parallel import DistributedDataParallel as apex_DDP
+from ..parallel.LARC import LARC
 
 def to_type(dtype, t):
@@ -142,7 +143,7 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
     from .amp import init as amp_init
 
     optimizers_was_list = False
-    if isinstance(optimizers, torch.optim.Optimizer):
+    if isinstance(optimizers, torch.optim.Optimizer) or isinstance(optimizers, LARC):
         optimizers = [optimizers]
     elif optimizers is None:
         optimizers = []
@@ -8,6 +8,7 @@ from .scaler import LossScaler
 from ._amp_state import _amp_state, master_params, maybe_print
 from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
 from ..optimizers import FP16_Optimizer as FP16_Optimizer_for_fused
+from ..parallel.LARC import LARC
 
 # There's no reason to expose the notion of a "handle". Everything can happen through amp.* calls.
@@ -84,7 +85,7 @@ def scale_loss(loss,
         yield loss
         return
 
-    if isinstance(optimizers, torch.optim.Optimizer):
+    if isinstance(optimizers, torch.optim.Optimizer) or isinstance(optimizers, LARC):
         optimizers = [optimizers]
 
     # this is what happens when i have to support tools from different sources under the same API...
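Both hunks relax the same check: _initialize and scale_loss previously only recognized a bare torch.optim.Optimizer, and a LARC wrapper (which does not subclass Optimizer) fell through. With the isinstance checks extended, a LARC-wrapped optimizer can be passed straight to amp.initialize and amp.scale_loss. Below is a minimal sketch of the workflow this enables; the model, data, and hyperparameters are illustrative placeholders, not part of this commit.

import torch
from apex import amp
from apex.parallel.LARC import LARC

# Placeholder model and data for illustration only.
model = torch.nn.Linear(128, 10).cuda()
criterion = torch.nn.CrossEntropyLoss()
inputs = torch.randn(32, 128, device="cuda")
targets = torch.randint(0, 10, (32,), device="cuda")

# Wrap a base optimizer in LARC, then hand the LARC instance to amp.
# Before this commit, the isinstance checks in _initialize and scale_loss
# only matched torch.optim.Optimizer, which LARC does not subclass.
optimizer = LARC(torch.optim.SGD(model.parameters(), lr=0.1))
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

optimizer.zero_grad()
loss = criterion(model(inputs), targets)
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()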