Commit 16a72b4d authored by Myle Ott
Browse files

Add more integration tests (LM, stories, transformer, lstm)

parent 736fbee2
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
import contextlib
from io import StringIO from io import StringIO
import os import os
import random import random
...@@ -20,23 +21,93 @@ import preprocess ...@@ -20,23 +21,93 @@ import preprocess
import train import train
import generate import generate
import interactive import interactive
import eval_lm
class TestTranslation(unittest.TestCase):
    """Smoke tests that run the translation pipeline on dummy data.

    Every test performs the same four steps inside a fresh temporary
    directory: create dummy data, preprocess it, train a model for the
    given architecture, and run ``generate_main``.  Standard output is
    redirected to an in-memory buffer for the whole duration of each test.
    """

    def test_fconv(self):
        """Run the pipeline with 'fconv_iwslt_de_en' and no extra flags."""
        with contextlib.redirect_stdout(StringIO()), \
                tempfile.TemporaryDirectory('test_fconv') as workdir:
            create_dummy_data(workdir)
            preprocess_translation_data(workdir)
            train_translation_model(workdir, 'fconv_iwslt_de_en')
            generate_main(workdir)

    def test_fp16(self):
        """Run the pipeline with 'fconv_iwslt_de_en' plus the '--fp16' flag."""
        with contextlib.redirect_stdout(StringIO()), \
                tempfile.TemporaryDirectory('test_fp16') as workdir:
            create_dummy_data(workdir)
            preprocess_translation_data(workdir)
            train_translation_model(workdir, 'fconv_iwslt_de_en', ['--fp16'])
            generate_main(workdir)

    def test_update_freq(self):
        """Run the pipeline with 'fconv_iwslt_de_en' and '--update-freq 3'."""
        with contextlib.redirect_stdout(StringIO()), \
                tempfile.TemporaryDirectory('test_update_freq') as workdir:
            create_dummy_data(workdir)
            preprocess_translation_data(workdir)
            train_translation_model(workdir, 'fconv_iwslt_de_en', ['--update-freq', '3'])
            generate_main(workdir)

    def test_lstm(self):
        """Run the pipeline with the 'lstm_wiseman_iwslt_de_en' architecture."""
        with contextlib.redirect_stdout(StringIO()), \
                tempfile.TemporaryDirectory('test_lstm') as workdir:
            create_dummy_data(workdir)
            preprocess_translation_data(workdir)
            train_translation_model(workdir, 'lstm_wiseman_iwslt_de_en')
            generate_main(workdir)

    def test_transformer(self):
        """Run the pipeline with the 'transformer_iwslt_de_en' architecture."""
        with contextlib.redirect_stdout(StringIO()), \
                tempfile.TemporaryDirectory('test_transformer') as workdir:
            create_dummy_data(workdir)
            preprocess_translation_data(workdir)
            train_translation_model(workdir, 'transformer_iwslt_de_en')
            generate_main(workdir)
class TestStories(unittest.TestCase):
    """Smoke test for the 'fconv_self_att_wp' architecture on dummy data."""

    def test_fconv_self_att_wp(self):
        """Train and generate once, then train again from the saved checkpoint.

        The first run trains 'fconv_self_att_wp' with a fixed set of layer
        and attention flags and calls ``generate_main``.  Its
        'checkpoint_last.pt' is then renamed to 'pretrained.pt' and passed
        to a second training run via '--pretrained-checkpoint', with
        '--save-dir' pointing at a 'fusion_model' subdirectory.
        """
        with contextlib.redirect_stdout(StringIO()), \
                tempfile.TemporaryDirectory('test_fconv_self_att_wp') as workdir:
            create_dummy_data(workdir)
            preprocess_translation_data(workdir)

            base_flags = [
                '--encoder-layers', '[(512, 3)] * 2',
                '--decoder-layers', '[(512, 3)] * 2',
                '--decoder-attention', 'True',
                '--encoder-attention', 'False',
                '--gated-attention', 'True',
                '--self-attention', 'True',
                '--project-input', 'True',
            ]
            train_translation_model(workdir, 'fconv_self_att_wp', base_flags)
            generate_main(workdir)

            # fusion model: reuse the checkpoint from the first run as the
            # pretrained model for a second training run
            pretrained_path = os.path.join(workdir, 'pretrained.pt')
            os.rename(os.path.join(workdir, 'checkpoint_last.pt'), pretrained_path)
            fusion_flags = base_flags + [
                '--pretrained', 'True',
                '--pretrained-checkpoint', pretrained_path,
                '--save-dir', os.path.join(workdir, 'fusion_model'),
            ]
            train_translation_model(workdir, 'fconv_self_att_wp', fusion_flags)
class TestBinaries(unittest.TestCase): class TestLanguageModeling(unittest.TestCase):
def test_binaries(self):
# comment this out to debug the unittest if it's failing
self.mock_stdout()
with tempfile.TemporaryDirectory() as data_dir: def test_fconv_lm(self):
self.create_dummy_data(data_dir) with contextlib.redirect_stdout(StringIO()):
self.preprocess_data(data_dir) with tempfile.TemporaryDirectory('test_fconv_lm') as data_dir:
self.train_model(data_dir) create_dummy_data(data_dir)
self.generate(data_dir) preprocess_lm_data(data_dir)
train_language_model(data_dir, 'fconv_lm')
eval_lm_main(data_dir)
self.unmock_stdout()
def create_dummy_data(self, data_dir, num_examples=1000, maxlen=20): def create_dummy_data(data_dir, num_examples=1000, maxlen=20):
def _create_dummy_data(filename): def _create_dummy_data(filename):
data = torch.rand(num_examples * maxlen) data = torch.rand(num_examples * maxlen)
...@@ -56,7 +127,8 @@ class TestBinaries(unittest.TestCase): ...@@ -56,7 +127,8 @@ class TestBinaries(unittest.TestCase):
_create_dummy_data('test.in') _create_dummy_data('test.in')
_create_dummy_data('test.out') _create_dummy_data('test.out')
def preprocess_data(self, data_dir):
def preprocess_translation_data(data_dir):
preprocess_parser = preprocess.get_parser() preprocess_parser = preprocess.get_parser()
preprocess_args = preprocess_parser.parse_args([ preprocess_args = preprocess_parser.parse_args([
'--source-lang', 'in', '--source-lang', 'in',
...@@ -70,33 +142,36 @@ class TestBinaries(unittest.TestCase): ...@@ -70,33 +142,36 @@ class TestBinaries(unittest.TestCase):
]) ])
preprocess.main(preprocess_args) preprocess.main(preprocess_args)
def train_model(self, data_dir):
def train_translation_model(data_dir, arch, extra_flags=None):
train_parser = options.get_training_parser() train_parser = options.get_training_parser()
train_args = options.parse_args_and_arch( train_args = options.parse_args_and_arch(
train_parser, train_parser,
[ [
data_dir, data_dir,
'--arch', 'fconv_iwslt_de_en', '--save-dir', data_dir,
'--arch', arch,
'--optimizer', 'nag', '--optimizer', 'nag',
'--lr', '0.05', '--lr', '0.05',
'--max-tokens', '500', '--max-tokens', '500',
'--save-dir', data_dir,
'--max-epoch', '1', '--max-epoch', '1',
'--no-progress-bar', '--no-progress-bar',
'--distributed-world-size', '1', '--distributed-world-size', '1',
'--source-lang', 'in', '--source-lang', 'in',
'--target-lang', 'out', '--target-lang', 'out',
], ] + (extra_flags or []),
) )
train.main(train_args) train.main(train_args)
def generate(self, data_dir):
def generate_main(data_dir):
generate_parser = options.get_generation_parser() generate_parser = options.get_generation_parser()
generate_args = generate_parser.parse_args([ generate_args = generate_parser.parse_args([
data_dir, data_dir,
'--path', os.path.join(data_dir, 'checkpoint_best.pt'), '--path', os.path.join(data_dir, 'checkpoint_last.pt'),
'--beam', '5', '--beam', '3',
'--batch-size', '32', '--batch-size', '64',
'--max-len-b', '5',
'--gen-subset', 'valid', '--gen-subset', 'valid',
'--no-progress-bar', '--no-progress-bar',
]) ])
...@@ -112,13 +187,51 @@ class TestBinaries(unittest.TestCase): ...@@ -112,13 +187,51 @@ class TestBinaries(unittest.TestCase):
interactive.main(generate_args) interactive.main(generate_args)
sys.stdin = orig_stdin sys.stdin = orig_stdin
def mock_stdout(self):
self._orig_stdout = sys.stdout
sys.stdout = StringIO()
def unmock_stdout(self): def preprocess_lm_data(data_dir):
if hasattr(self, '_orig_stdout'): preprocess_parser = preprocess.get_parser()
sys.stdout = self._orig_stdout preprocess_args = preprocess_parser.parse_args([
'--only-source',
'--trainpref', os.path.join(data_dir, 'train.out'),
'--validpref', os.path.join(data_dir, 'valid.out'),
'--testpref', os.path.join(data_dir, 'test.out'),
'--destdir', data_dir,
])
preprocess.main(preprocess_args)
def train_language_model(data_dir, arch):
    """Run ``train.main`` for ``arch`` with a fixed language-model flag set.

    ``data_dir`` is passed both as the positional data argument and as
    '--save-dir'.  The remaining flags are hard-coded: the 'nag' optimizer,
    learning rate 1.0, the 'adaptive_loss' criterion with cutoffs
    '5,10,15', a fixed decoder layer spec and embedding size, and a single
    epoch with a distributed world size of 1.
    """
    cli_args = [
        data_dir,
        '--arch', arch,
        '--optimizer', 'nag',
        '--lr', '1.0',
        '--criterion', 'adaptive_loss',
        '--adaptive-softmax-cutoff', '5,10,15',
        '--decoder-layers', '[(850, 3)] * 2 + [(1024,4)]',
        '--decoder-embed-dim', '280',
        '--max-tokens', '500',
        '--max-target-positions', '500',
        '--save-dir', data_dir,
        '--max-epoch', '1',
        '--no-progress-bar',
        '--distributed-world-size', '1',
    ]
    parser = options.get_training_parser()
    train.main(options.parse_args_and_arch(parser, cli_args))
def eval_lm_main(data_dir):
    """Run ``eval_lm.main`` on ``data_dir`` using its 'checkpoint_last.pt'.

    The checkpoint path is built by joining ``data_dir`` with
    'checkpoint_last.pt'; the progress bar is disabled.
    """
    checkpoint_path = os.path.join(data_dir, 'checkpoint_last.pt')
    cli_args = [
        data_dir,
        '--path', checkpoint_path,
        '--no-progress-bar',
    ]
    parser = options.get_eval_lm_parser()
    eval_lm.main(parser.parse_args(cli_args))
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment