Commit 29c82741 authored by Myle Ott's avatar Myle Ott Committed by Sergey Edunov
Browse files

Fix tests and flake8

parent b9f2d427
...@@ -39,6 +39,7 @@ def get_parser(): ...@@ -39,6 +39,7 @@ def get_parser():
parser.add_argument('--joined-dictionary', action='store_true', help='Generate joined dictionary') parser.add_argument('--joined-dictionary', action='store_true', help='Generate joined dictionary')
return parser return parser
def main(args): def main(args):
print(args) print(args)
os.makedirs(args.destdir, exist_ok=True) os.makedirs(args.destdir, exist_ok=True)
......
...@@ -11,12 +11,16 @@ import random ...@@ -11,12 +11,16 @@ import random
import sys import sys
import tempfile import tempfile
import unittest import unittest
from unittest import mock
import torch import torch
from fairseq import options from fairseq import options
import preprocess, train, generate, interactive import preprocess
import train
import generate
import interactive
class TestBinaries(unittest.TestCase): class TestBinaries(unittest.TestCase):
...@@ -82,6 +86,8 @@ class TestBinaries(unittest.TestCase): ...@@ -82,6 +86,8 @@ class TestBinaries(unittest.TestCase):
'--no-progress-bar', '--no-progress-bar',
], ],
) )
with mock.patch('train.torch.cuda.device_count') as device_count:
device_count.return_value = 1
train.main(train_args) train.main(train_args)
def generate(self, data_dir): def generate(self, data_dir):
......
...@@ -14,6 +14,7 @@ from distributed_train import main as distributed_main ...@@ -14,6 +14,7 @@ from distributed_train import main as distributed_main
from multiprocessing_train import main as multiprocessing_main from multiprocessing_train import main as multiprocessing_main
from singleprocess_train import main as singleprocess_main from singleprocess_train import main as singleprocess_main
def main(args): def main(args):
if args.distributed_port > 0 \ if args.distributed_port > 0 \
or args.distributed_init_method is not None: or args.distributed_init_method is not None:
...@@ -23,6 +24,7 @@ def main(args): ...@@ -23,6 +24,7 @@ def main(args):
else: else:
singleprocess_main(args) singleprocess_main(args)
if __name__ == '__main__': if __name__ == '__main__':
parser = options.get_training_parser() parser = options.get_training_parser()
args = options.parse_args_and_arch(parser) args = options.parse_args_and_arch(parser)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment