Commit bc8ae449 authored by Xian Li's avatar Xian Li Committed by Facebook Github Bot
Browse files

refactor AdversarialTrainer factor out helper functions

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/474

Reviewed By: theweiho, akinh

Differential Revision: D13701447

fbshipit-source-id: 34036dce7601835b605e3b169210edc7a6715de6
parent 3e67386b
...@@ -158,13 +158,7 @@ class Trainer(object): ...@@ -158,13 +158,7 @@ class Trainer(object):
def train_step(self, samples, dummy_batch=False): def train_step(self, samples, dummy_batch=False):
"""Do forward, backward and parameter update.""" """Do forward, backward and parameter update."""
# Set seed based on args.seed and the update number so that we get self._set_seed()
# reproducible results when resuming from checkpoints
seed = self.args.seed + self.get_num_updates()
torch.manual_seed(seed)
if self.cuda:
torch.cuda.manual_seed(seed)
self.model.train() self.model.train()
self.criterion.train() self.criterion.train()
self.zero_grad() self.zero_grad()
...@@ -395,3 +389,11 @@ class Trainer(object): ...@@ -395,3 +389,11 @@ class Trainer(object):
if self.cuda: if self.cuda:
sample = utils.move_to_cuda(sample) sample = utils.move_to_cuda(sample)
return sample return sample
def _set_seed(self):
# Set seed based on args.seed and the update number so that we get
# reproducible results when resuming from checkpoints
seed = self.args.seed + self.get_num_updates()
torch.manual_seed(seed)
if self.cuda:
torch.cuda.manual_seed(seed)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment