Commit 397ba265 authored by Taylan Bilal's avatar Taylan Bilal Committed by Facebook Github Bot
Browse files

adding __getstate__ to fairseq_optimizer (#872)

Summary:
self._optimizer has __getstate__
We need this so that fairseq_optimizer's work with pytorch/xla

```
% find . | xargs grep -s -i __getstate__
./third_party/tensorflow/tensorflow/python/util/deprecation_wrapper.py:  def __getstate__(self):
./torch_xla_py/xla_model.py:  for param_group in optimizer.__getstate__()['param_groups']:
```
Pull Request resolved: https://github.com/pytorch/fairseq/pull/872

Differential Revision: D16211062

Pulled By: alexeib

fbshipit-source-id: 1b5575c85d34b7b021d719a03fd58d1c2ee453ee
parent 0a4911f6
...@@ -41,6 +41,9 @@ class FairseqOptimizer(object): ...@@ -41,6 +41,9 @@ class FairseqOptimizer(object):
""" """
raise NotImplementedError raise NotImplementedError
def __getstate__(self):
return self._optimizer.__getstate__()
def get_lr(self): def get_lr(self):
"""Return the current learning rate.""" """Return the current learning rate."""
return self.optimizer.param_groups[0]['lr'] return self.optimizer.param_groups[0]['lr']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment