Commit bc8ae449 authored by Xian Li's avatar Xian Li Committed by Facebook Github Bot
Browse files

refactor AdversarialTrainer factor out helper functions

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/474

Reviewed By: theweiho, akinh

Differential Revision: D13701447

fbshipit-source-id: 34036dce7601835b605e3b169210edc7a6715de6
parent 3e67386b
......@@ -158,13 +158,7 @@ class Trainer(object):
def train_step(self, samples, dummy_batch=False):
"""Do forward, backward and parameter update."""
# Set seed based on args.seed and the update number so that we get
# reproducible results when resuming from checkpoints
seed = self.args.seed + self.get_num_updates()
torch.manual_seed(seed)
if self.cuda:
torch.cuda.manual_seed(seed)
self._set_seed()
self.model.train()
self.criterion.train()
self.zero_grad()
......@@ -395,3 +389,11 @@ class Trainer(object):
if self.cuda:
sample = utils.move_to_cuda(sample)
return sample
def _set_seed(self):
# Set seed based on args.seed and the update number so that we get
# reproducible results when resuming from checkpoints
seed = self.args.seed + self.get_num_updates()
torch.manual_seed(seed)
if self.cuda:
torch.cuda.manual_seed(seed)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment