Commit 4b1f4788 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Misc fixes

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/439

Differential Revision: D13608151

Pulled By: myleott

fbshipit-source-id: 198b84995a6329f8329829cc91184d88f1eab947
parent 73876ce3
...@@ -162,6 +162,7 @@ class Trainer(object): ...@@ -162,6 +162,7 @@ class Trainer(object):
torch.cuda.manual_seed(seed) torch.cuda.manual_seed(seed)
self.model.train() self.model.train()
self.criterion.train()
self.zero_grad() self.zero_grad()
if not dummy_batch: if not dummy_batch:
...@@ -286,6 +287,7 @@ class Trainer(object): ...@@ -286,6 +287,7 @@ class Trainer(object):
"""Do forward pass in evaluation mode.""" """Do forward pass in evaluation mode."""
with torch.no_grad(): with torch.no_grad():
self.model.eval() self.model.eval()
self.criterion.eval()
sample = self._prepare_sample(sample) sample = self._prepare_sample(sample)
if sample is None: if sample is None:
......
...@@ -375,8 +375,6 @@ if __name__ == '__main__': ...@@ -375,8 +375,6 @@ if __name__ == '__main__':
if args.distributed_init_method is not None: if args.distributed_init_method is not None:
# distributed training # distributed training
distributed_main(args.device_id, args) distributed_main(args.device_id, args)
args.distributed_rank = distributed_utils.distributed_init(args)
main(args)
elif args.distributed_world_size > 1: elif args.distributed_world_size > 1:
# fallback for single node with multiple GPUs # fallback for single node with multiple GPUs
port = random.randint(10000, 20000) port = random.randint(10000, 20000)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment