Commit 0d90e35f authored by Myle Ott, committed by Sergey Edunov

More unit test fixes

parent 29c82741
@@ -7,6 +7,8 @@
 import argparse
+import torch
 from fairseq.criterions import CRITERION_REGISTRY
 from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY
 from fairseq.optim import OPTIMIZER_REGISTRY
@@ -117,8 +119,9 @@ def add_dataset_args(parser, train=False, gen=False):
 def add_distributed_training_args(parser):
     group = parser.add_argument_group('Distributed training')
-    group.add_argument('--distributed-world-size', default=1, type=int, metavar='N',
-                       help='total number of GPUs across all nodes, default: 1 GPU')
+    group.add_argument('--distributed-world-size', type=int, metavar='N',
+                       default=torch.cuda.device_count(),
+                       help='total number of GPUs across all nodes (default: all visible GPUs)')
     group.add_argument('--distributed-rank', default=0, type=int,
                        help='rank of the current worker')
     group.add_argument('--distributed-backend', default='nccl', type=str,
...
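The default for --distributed-world-size now resolves to the number of visible GPUs at the moment the parser is built. A minimal, self-contained sketch of that argparse pattern (not the fairseq parser itself):

import argparse
import torch

# The default is evaluated once, at parser-construction time.
parser = argparse.ArgumentParser()
group = parser.add_argument_group('Distributed training')
group.add_argument('--distributed-world-size', type=int, metavar='N',
                   default=torch.cuda.device_count(),
                   help='total number of GPUs across all nodes (default: all visible GPUs)')
args = parser.parse_args([])
print(args.distributed_world_size)  # 0 on a CPU-only machine, otherwise the GPU count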
@@ -90,6 +90,7 @@ class SequenceGenerator(object):
             for model in self.models:
                 if isinstance(model.decoder, FairseqIncrementalDecoder):
                     stack.enter_context(model.decoder.incremental_inference())
-            return self._generate(src_tokens, src_lengths, beam_size, maxlen)
+            with utils.maybe_no_grad():
+                return self._generate(src_tokens, src_lengths, beam_size, maxlen)

     def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None):
...
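utils.maybe_no_grad() presumably disables gradient tracking around generation when the installed PyTorch supports it. A hedged sketch of such a helper, assuming a fallback to a no-op on releases without torch.no_grad; the actual fairseq implementation may differ:

import contextlib
import torch

@contextlib.contextmanager
def maybe_no_grad(condition=True):
    # Use torch.no_grad() when available; otherwise act as a no-op so the same
    # call site works on older PyTorch releases without grad-mode control.
    if condition and hasattr(torch, 'no_grad'):
        with torch.no_grad():
            yield
    else:
        yield

# Usage: wrap inference so no autograd graph is built.
with maybe_no_grad():
    x = torch.ones(2, 2)
    y = x * 2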
@@ -18,8 +18,6 @@ def main(args):
     print(args)
     use_cuda = torch.cuda.is_available() and not args.cpu
-    if hasattr(torch, 'set_grad_enabled'):
-        torch.set_grad_enabled(False)

     # Load dataset
     if args.replace_unk is None:
...
@@ -18,8 +18,6 @@ def main(args):
     print(args)
     use_cuda = torch.cuda.is_available() and not args.cpu
-    if hasattr(torch, 'set_grad_enabled'):
-        torch.set_grad_enabled(False)

     # Load ensemble
     print('| loading model(s) from {}'.format(', '.join(args.path)))
...
@@ -11,7 +11,6 @@ import random
 import sys
 import tempfile
 import unittest
-from unittest import mock

 import torch
@@ -84,10 +83,9 @@ class TestBinaries(unittest.TestCase):
                 '--save-dir', data_dir,
                 '--max-epoch', '1',
                 '--no-progress-bar',
+                '--distributed-world-size', '1',
             ],
         )
-        with mock.patch('train.torch.cuda.device_count') as device_count:
-            device_count.return_value = 1
-            train.main(train_args)
+        train.main(train_args)

     def generate(self, data_dir):
...
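With --distributed-world-size passed explicitly, the test no longer needs to patch torch.cuda.device_count. For reference, a hedged, standalone sketch of the patching pattern that was removed (unittest.mock only; the real test patched the symbol through the train module):

import unittest
from unittest import mock

import torch

class TestDeviceCount(unittest.TestCase):
    def test_sees_single_gpu(self):
        # Temporarily make CUDA report exactly one device, regardless of hardware.
        with mock.patch('torch.cuda.device_count', return_value=1):
            self.assertEqual(torch.cuda.device_count(), 1)

if __name__ == '__main__':
    unittest.main()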
@@ -19,7 +19,7 @@ def main(args):
     if args.distributed_port > 0 \
             or args.distributed_init_method is not None:
         distributed_main(args)
-    elif torch.cuda.device_count() > 1:
+    elif args.distributed_world_size > 1:
         multiprocessing_main(args)
     else:
         singleprocess_main(args)
...
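train.main() now dispatches on the parsed arguments rather than on the machine's GPU count, which is what lets the test above force single-process training. A hedged, self-contained sketch of that dispatch; the *_main functions below are placeholders, not the real fairseq entry points:

from argparse import Namespace

def distributed_main(args):
    print('multi-node distributed training')

def multiprocessing_main(args):
    print('one worker process per GPU on this node')

def singleprocess_main(args):
    print('single-process training')

def main(args):
    if args.distributed_port > 0 or args.distributed_init_method is not None:
        distributed_main(args)
    elif args.distributed_world_size > 1:
        multiprocessing_main(args)
    else:
        singleprocess_main(args)

# With --distributed-world-size 1 and no explicit init method, this takes the
# single-process path even on a multi-GPU machine.
main(Namespace(distributed_port=-1, distributed_init_method=None,
               distributed_world_size=1))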