Commit 7176667d authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Fix apex Adam to not break CPU mode

Reviewed By: chenyangyu1988

Differential Revision: D14784219

fbshipit-source-id: 273888d6e3d22a01d5e7edfbc786195e7b78efef
parent 437c2386
...@@ -14,13 +14,17 @@ from . import FairseqOptimizer, register_optimizer ...@@ -14,13 +14,17 @@ from . import FairseqOptimizer, register_optimizer
@register_optimizer('adam') @register_optimizer('adam')
class FairseqAdam(FairseqOptimizer): class FairseqAdam(FairseqOptimizer):
def __init__(self, args, params): def __init__(self, args, params):
super().__init__(args, params) super().__init__(args, params)
if torch.cuda.is_available():
try: try:
from apex.optimizers import FusedAdam from apex.optimizers import FusedAdam
self._optimizer = FusedAdam(params, **self.optimizer_config) self._optimizer = FusedAdam(params, **self.optimizer_config)
except ImportError: except ImportError:
self._optimizer = Adam(params, **self.optimizer_config) self._optimizer = Adam(params, **self.optimizer_config)
else:
self._optimizer = Adam(params, **self.optimizer_config)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment