Commit 29c82741 authored by Myle Ott's avatar Myle Ott Committed by Sergey Edunov
Browse files

Fix tests and flake8

parent b9f2d427
......@@ -47,13 +47,13 @@ class Trainer(object):
self.meters['train_nll_loss'] = AverageMeter()
self.meters['valid_loss'] = AverageMeter()
self.meters['valid_nll_loss'] = AverageMeter()
self.meters['wps'] = TimeMeter() # words per second
self.meters['ups'] = TimeMeter() # updates per second
self.meters['wpb'] = AverageMeter() # words per batch
self.meters['bsz'] = AverageMeter() # sentences per batch
self.meters['gnorm'] = AverageMeter() # gradient norm
self.meters['clip'] = AverageMeter() # % of updates clipped
self.meters['oom'] = AverageMeter() # out of memory
self.meters['wps'] = TimeMeter() # words per second
self.meters['ups'] = TimeMeter() # updates per second
self.meters['wpb'] = AverageMeter() # words per batch
self.meters['bsz'] = AverageMeter() # sentences per batch
self.meters['gnorm'] = AverageMeter() # gradient norm
self.meters['clip'] = AverageMeter() # % of updates clipped
self.meters['oom'] = AverageMeter() # out of memory
self._max_bsz_seen = 0
self._num_updates = 0
......
......@@ -39,6 +39,7 @@ def get_parser():
parser.add_argument('--joined-dictionary', action='store_true', help='Generate joined dictionary')
return parser
def main(args):
print(args)
os.makedirs(args.destdir, exist_ok=True)
......
......@@ -11,12 +11,16 @@ import random
import sys
import tempfile
import unittest
from unittest import mock
import torch
from fairseq import options
import preprocess, train, generate, interactive
import preprocess
import train
import generate
import interactive
class TestBinaries(unittest.TestCase):
......@@ -82,7 +86,9 @@ class TestBinaries(unittest.TestCase):
'--no-progress-bar',
],
)
train.main(train_args)
with mock.patch('train.torch.cuda.device_count') as device_count:
device_count.return_value = 1
train.main(train_args)
def generate(self, data_dir):
generate_parser = options.get_generation_parser()
......
......@@ -14,6 +14,7 @@ from distributed_train import main as distributed_main
from multiprocessing_train import main as multiprocessing_main
from singleprocess_train import main as singleprocess_main
def main(args):
if args.distributed_port > 0 \
or args.distributed_init_method is not None:
......@@ -23,6 +24,7 @@ def main(args):
else:
singleprocess_main(args)
if __name__ == '__main__':
parser = options.get_training_parser()
args = options.parse_args_and_arch(parser)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment