".jenkins/git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "c3d4bfe8224967baef120be1ce6a8a0cc82c12ea"
Commit 81e1e248 authored by Li Li

Fix optimizer to work with horovod

parent a2b6918a
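
Context on the fix (inferred from the change below, not stated in the commit itself): Horovod's DistributedOptimizer wraps an existing optimizer by dynamically subclassing it and re-instantiating that class from the original optimizer's param_groups, without passing lr again. While lr was a bare positional argument, that re-instantiation could not succeed; defaulting it to torch's `required` sentinel lets construction go through, because every param group already records its lr. A hedged usage sketch of the call path this unblocks, where the model, the hyperparameter values, and the BertAdam import path are placeholders:

    import torch
    import horovod.torch as hvd
    from apex.optimizers import BertAdam  # assumed import path for this repo

    hvd.init()
    model = torch.nn.Linear(768, 2)  # placeholder model

    # lr is recorded in the optimizer's param_groups at construction time.
    optimizer = BertAdam(model.parameters(), lr=5e-5, warmup=0.1, t_total=1000)

    # DistributedOptimizer rebuilds the optimizer class from param_groups and
    # does not repeat lr; with lr positional, that rebuild failed.
    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters())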
@@ -17,6 +17,7 @@
 import math
 import torch
 from torch.optim import Optimizer
+from torch.optim.optimizer import required
 from torch.nn.utils import clip_grad_norm_
 
 def warmup_cosine(x, warmup=0.002):
@@ -55,10 +56,10 @@ class BertAdam(Optimizer):
         weight_decay_rate: Weight decay. Default: 0.01
         max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
     """
-    def __init__(self, params, lr, warmup=-1, t_total=-1, schedule='warmup_linear',
+    def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
                  b1=0.9, b2=0.999, e=1e-6, weight_decay_rate=0.01,
                  max_grad_norm=1.0):
-        if not lr >= 0.0:
+        if lr is not required and lr < 0.0:
             raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
         if schedule not in SCHEDULES:
             raise ValueError("Invalid schedule parameter: {}".format(schedule))
...
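
`required` here is the module-level sentinel that torch's built-in optimizers (e.g. torch.optim.SGD) use for options that must be supplied either at construction time or in every parameter group; Optimizer.add_param_group raises if a required default is provided by neither. A minimal runnable sketch of the pattern, with a hypothetical SketchAdam standing in for BertAdam and only the changed lines mirrored:

    import torch
    from torch.optim import Optimizer
    from torch.optim.optimizer import required

    class SketchAdam(Optimizer):  # hypothetical name, for illustration only
        def __init__(self, params, lr=required):
            # Validate only a value the caller actually supplied; `required`
            # means "each param group must provide its own lr".
            if lr is not required and lr < 0.0:
                raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
            super(SketchAdam, self).__init__(params, defaults=dict(lr=lr))

    w = torch.nn.Linear(4, 4)
    # Succeeds: the param group carries lr, so the `required` default is never used.
    opt = SketchAdam([{'params': w.parameters(), 'lr': 1e-3}])
    # Would raise ValueError inside Optimizer.add_param_group: no lr anywhere.
    # opt = SketchAdam(w.parameters())

This is the construction mode Horovod-style wrapping relies on: the optimizer is rebuilt from param_groups alone, with no repeated keyword arguments.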