# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
# -------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from . import FairseqOptimizer, register_optimizer

# FusedAdam is NVIDIA apex's fused CUDA implementation of Adam; importing it
# requires apex to be installed with its CUDA extensions built.
from apex.optimizers.fused_adam import FusedAdam


@register_optimizer('adam')
class FairseqAdam(FairseqOptimizer):

    def __init__(self, args, params):
        super().__init__(args, params)
        self._optimizer = FusedAdam(params, **self.optimizer_config)

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        parser.add_argument('--adam-betas', default=(0.9, 0.999), nargs=2,
                            type=float, metavar=('B1', 'B2'),
                            help='betas for Adam optimizer')
        parser.add_argument('--adam-eps', type=float, default=1e-8, metavar='D',
                            help='epsilon for Adam optimizer')

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with
        a different learning rate.
        """
        return {
            'lr': self.args.lr[0],
            'betas': self.args.adam_betas,
            'eps': self.args.adam_eps,
            'weight_decay': self.args.weight_decay,
        }
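

# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original module). It
# shows how the registered 'adam' optimizer could be built from parsed
# command-line arguments. The '--lr' and '--weight-decay' options are normally
# added by fairseq's top-level option parser; they are declared here only so
# the example is self-contained, and the toy model is an assumption.
#
#     import argparse
#     import torch
#
#     parser = argparse.ArgumentParser()
#     FairseqAdam.add_args(parser)
#     parser.add_argument('--lr', type=float, nargs='+', default=[0.001])
#     parser.add_argument('--weight-decay', type=float, default=0.0)
#     args = parser.parse_args(['--adam-betas', '0.9', '0.98'])
#
#     model = torch.nn.Linear(16, 16).cuda()  # FusedAdam expects CUDA tensors
#     optimizer = FairseqAdam(args, list(model.parameters()))
#     # optimizer._optimizer is now an apex FusedAdam configured with
#     # lr=0.001, betas=[0.9, 0.98], eps=1e-8, weight_decay=0.0
# ---------------------------------------------------------------------------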
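
# ---------------------------------------------------------------------------
# Design note (an assumption, not in the original file): the hard dependency
# on apex means this module fails to import on machines without apex. One way
# to degrade gracefully would be a guarded import that falls back to the
# pure-PyTorch Adam, which accepts the same lr/betas/eps/weight_decay kwargs
# used by optimizer_config above:
#
#     try:
#         from apex.optimizers.fused_adam import FusedAdam
#     except ImportError:
#         # Slower unfused fallback; API-compatible for the kwargs used here.
#         from torch.optim import Adam as FusedAdam
# ---------------------------------------------------------------------------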