Commit 0d90e35f authored by Myle Ott, committed by Sergey Edunov

More unit test fixes

parent 29c82741
fairseq/options.py
@@ -7,6 +7,8 @@
 import argparse
 
+import torch
+
 from fairseq.criterions import CRITERION_REGISTRY
 from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY
 from fairseq.optim import OPTIMIZER_REGISTRY
@@ -117,8 +119,9 @@ def add_dataset_args(parser, train=False, gen=False):
 def add_distributed_training_args(parser):
     group = parser.add_argument_group('Distributed training')
-    group.add_argument('--distributed-world-size', default=1, type=int, metavar='N',
-                       help='total number of GPUs across all nodes, default: 1 GPU')
+    group.add_argument('--distributed-world-size', type=int, metavar='N',
+                       default=torch.cuda.device_count(),
+                       help='total number of GPUs across all nodes (default: all visible GPUs)')
     group.add_argument('--distributed-rank', default=0, type=int,
                        help='rank of the current worker')
     group.add_argument('--distributed-backend', default='nccl', type=str,
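
Note on the new default: torch.cuda.device_count() runs once, when the parser is built, so the world size tracks however many GPUs are visible at that moment, and an explicit flag still wins. A minimal standalone sketch of the behavior (not fairseq's actual module layout):

import argparse
import torch

parser = argparse.ArgumentParser()
group = parser.add_argument_group('Distributed training')
group.add_argument('--distributed-world-size', type=int, metavar='N',
                   default=torch.cuda.device_count(),
                   help='total number of GPUs across all nodes '
                        '(default: all visible GPUs)')

print(parser.parse_args([]).distributed_world_size)
# 0 on a CPU-only machine, N on an N-GPU machine
print(parser.parse_args(['--distributed-world-size', '1']).distributed_world_size)
# 1: an explicit value overrides the computed default
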
fairseq/sequence_generator.py
@@ -90,6 +90,7 @@ class SequenceGenerator(object):
         for model in self.models:
             if isinstance(model.decoder, FairseqIncrementalDecoder):
                 stack.enter_context(model.decoder.incremental_inference())
-        return self._generate(src_tokens, src_lengths, beam_size, maxlen)
+        with utils.maybe_no_grad():
+            return self._generate(src_tokens, src_lengths, beam_size, maxlen)
 
     def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None):
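
The utils.maybe_no_grad() helper itself is not part of this diff. A plausible sketch, assuming it is a thin compatibility wrapper that uses torch.no_grad() when the installed PyTorch provides it (older releases predate the autograd context managers) and otherwise does nothing:

import contextlib
import torch

@contextlib.contextmanager
def maybe_no_grad(condition=True):
    # Disable gradient tracking where supported; no-op on older PyTorch.
    if condition and hasattr(torch, 'no_grad'):
        with torch.no_grad():
            yield
    else:
        yield

Wrapping generation here is what lets the scripts below drop their manual set_grad_enabled(False) calls.
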
generate.py
@@ -18,8 +18,6 @@ def main(args):
     print(args)
     use_cuda = torch.cuda.is_available() and not args.cpu
-    if hasattr(torch, 'set_grad_enabled'):
-        torch.set_grad_enabled(False)
 
     # Load dataset
     if args.replace_unk is None:
interactive.py
@@ -18,8 +18,6 @@ def main(args):
     print(args)
     use_cuda = torch.cuda.is_available() and not args.cpu
-    if hasattr(torch, 'set_grad_enabled'):
-        torch.set_grad_enabled(False)
 
     # Load ensemble
     print('| loading model(s) from {}'.format(', '.join(args.path)))
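
With generation wrapped in a no-grad context inside SequenceGenerator, the scripts no longer need to flip the global autograd switch, which also removes the risk of leaving gradients disabled for unrelated code. A toy illustration of the difference, in plain PyTorch and independent of fairseq:

import torch

x = torch.ones(1, requires_grad=True)

# Global switch (the pattern removed above): stays off until explicitly restored.
torch.set_grad_enabled(False)
print((x * 2).requires_grad)      # False
torch.set_grad_enabled(True)

# Scoped context (what the generator now uses): restored automatically on exit.
with torch.no_grad():
    print((x * 2).requires_grad)  # False
print((x * 2).requires_grad)      # True again
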
tests/test_binaries.py
@@ -11,7 +11,6 @@
 import random
 import sys
 import tempfile
 import unittest
-from unittest import mock
 
 import torch
@@ -84,10 +83,9 @@ class TestBinaries(unittest.TestCase):
                 '--save-dir', data_dir,
                 '--max-epoch', '1',
                 '--no-progress-bar',
+                '--distributed-world-size', '1',
             ],
         )
-        with mock.patch('train.torch.cuda.device_count') as device_count:
-            device_count.return_value = 1
-            train.main(train_args)
+        train.main(train_args)
 
     def generate(self, data_dir):
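
The mock made sense when train.py called torch.cuda.device_count() at dispatch time; now the dispatch keys off args.distributed_world_size, whose default is computed once when options.py builds the parser, so patching device_count afterwards would change nothing. Passing --distributed-world-size 1 is the reliable override. A self-contained sketch of the timing issue (device_count here is a hypothetical stand-in, not the real CUDA call):

import argparse
from unittest import mock

def device_count():
    return 4  # stand-in for torch.cuda.device_count() on a 4-GPU host

parser = argparse.ArgumentParser()
parser.add_argument('--distributed-world-size', type=int,
                    default=device_count())  # evaluated right here

with mock.patch(f'{__name__}.device_count', return_value=1):
    args = parser.parse_args([])

print(args.distributed_world_size)  # 4: the patch arrived after the default was baked in
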
train.py
@@ -19,7 +19,7 @@ def main(args):
     if args.distributed_port > 0 \
             or args.distributed_init_method is not None:
         distributed_main(args)
-    elif torch.cuda.device_count() > 1:
+    elif args.distributed_world_size > 1:
         multiprocessing_main(args)
     else:
         singleprocess_main(args)
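
Taken together, the dispatch now depends only on parsed arguments, so the test suite (or a user) can force the single-process path on any machine. A condensed, hypothetical view of the decision logic with a fake args object:

def dispatch(args):
    if args.distributed_port > 0 or args.distributed_init_method is not None:
        return 'distributed'        # explicit multi-node setup
    elif args.distributed_world_size > 1:
        return 'multiprocessing'    # one node, several workers requested
    else:
        return 'singleprocess'

class Args:
    distributed_port = -1
    distributed_init_method = None
    distributed_world_size = 1      # what the test suite now passes

print(dispatch(Args()))  # 'singleprocess', even on a multi-GPU host
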