Commit cea12eaf authored by Maciej Torhan's avatar Maciej Torhan Committed by Facebook GitHub Bot
Browse files

Fix Adam and AdamW initializers in wav2letter example (#3145)

Summary:
In the wav2letter example, `momentum` is passed to the `Adam` and `AdamW` initializers, but it is not a valid parameter for them. To fix that, we add `beta_1` and `beta_2` arguments and replace `momentum` with them. I also added `eps`, similar to the `Adadelta` initializer.

Pull Request resolved: https://github.com/pytorch/audio/pull/3145

Reviewed By: mthrok

Differential Revision: D43847713

Pulled By: nateanl

fbshipit-source-id: 94f7c48232fabf520cfce81471694cb545d160c6
parent 8a9ab2a4
...@@ -130,6 +130,8 @@ def parse_args(): ...@@ -130,6 +130,8 @@ def parse_args():
help="learning rate exponential decay constant", help="learning rate exponential decay constant",
) )
parser.add_argument("--momentum", default=0.8, type=float, metavar="M", help="momentum") parser.add_argument("--momentum", default=0.8, type=float, metavar="M", help="momentum")
parser.add_argument("--beta_1", default=0.9, type=float, metavar="BETA_1", help="beta_1")
parser.add_argument("--beta_2", default=0.999, type=float, metavar="BETA_2", help="beta_2")
parser.add_argument("--weight-decay", default=1e-5, type=float, metavar="W", help="weight decay") parser.add_argument("--weight-decay", default=1e-5, type=float, metavar="W", help="weight decay")
parser.add_argument("--eps", metavar="EPS", type=float, default=1e-8) parser.add_argument("--eps", metavar="EPS", type=float, default=1e-8)
parser.add_argument("--rho", metavar="RHO", type=float, default=0.95) parser.add_argument("--rho", metavar="RHO", type=float, default=0.95)
...@@ -472,15 +474,17 @@ def main(rank, args): ...@@ -472,15 +474,17 @@ def main(rank, args):
optimizer = Adam( optimizer = Adam(
model.parameters(), model.parameters(),
lr=args.learning_rate, lr=args.learning_rate,
momentum=args.momentum, betas=(args.beta_1, args.beta_2),
weight_decay=args.weight_decay, weight_decay=args.weight_decay,
eps=args.eps,
) )
elif args.optimizer == "adamw": elif args.optimizer == "adamw":
optimizer = AdamW( optimizer = AdamW(
model.parameters(), model.parameters(),
lr=args.learning_rate, lr=args.learning_rate,
momentum=args.momentum, betas=(args.beta_1, args.beta_2),
weight_decay=args.weight_decay, weight_decay=args.weight_decay,
eps=args.eps,
) )
else: else:
raise ValueError("Selected optimizer not supported") raise ValueError("Selected optimizer not supported")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment