"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "720dbfc985de51de623327ff3deefd37a77c808c"
Commit 16a72b4d authored by Myle Ott

Add more integration tests (LM, stories, transformer, lstm)

parent 736fbee2
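
This commit splits the old all-in-one TestBinaries class into per-task suites (TestTranslation, TestStories, TestLanguageModeling) backed by shared module-level helpers, and replaces the hand-rolled mock_stdout()/unmock_stdout() pair with the standard library's contextlib.redirect_stdout. A minimal standalone sketch of that capture idiom, not part of the test file itself:

import contextlib
from io import StringIO

buf = StringIO()
with contextlib.redirect_stdout(buf):
    print('training logs are swallowed here')  # anything printed inside the block
print('captured %r' % buf.getvalue())          # is available on the buffer afterwards

The tests pass a throwaway StringIO() and never read it back, so commenting the redirect out is the quickest way to see a failing test's output. The full post-commit test file follows; the unchanged import block that the diff viewer collapsed (sys, tempfile, unittest, torch, and fairseq's options module) is filled in from usage, so its exact form is a best guess.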
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import contextlib
from io import StringIO
import os
import random
import sys
import tempfile
import unittest

import torch

from fairseq import options

import preprocess
import train
import generate
import interactive
import eval_lm

class TestTranslation(unittest.TestCase):

    def test_fconv(self):
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory('test_fconv') as data_dir:
                create_dummy_data(data_dir)
                preprocess_translation_data(data_dir)
                train_translation_model(data_dir, 'fconv_iwslt_de_en')
                generate_main(data_dir)

    def test_fp16(self):
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory('test_fp16') as data_dir:
                create_dummy_data(data_dir)
                preprocess_translation_data(data_dir)
                train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--fp16'])
                generate_main(data_dir)

    def test_update_freq(self):
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory('test_update_freq') as data_dir:
                create_dummy_data(data_dir)
                preprocess_translation_data(data_dir)
                train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--update-freq', '3'])
                generate_main(data_dir)

    def test_lstm(self):
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory('test_lstm') as data_dir:
                create_dummy_data(data_dir)
                preprocess_translation_data(data_dir)
                train_translation_model(data_dir, 'lstm_wiseman_iwslt_de_en')
                generate_main(data_dir)

    def test_transformer(self):
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory('test_transformer') as data_dir:
                create_dummy_data(data_dir)
                preprocess_translation_data(data_dir)
                train_translation_model(data_dir, 'transformer_iwslt_de_en')
                generate_main(data_dir)
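
Each translation test runs the same four-step pipeline (synthesize a tiny character-level corpus, binarize it, train for one epoch, decode) and differs only in the --arch value and extra flags. Assuming the file lives at tests/test_binaries.py, a single case can be run in isolation with, e.g., python -m unittest tests.test_binaries.TestTranslation.test_lstm.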

class TestStories(unittest.TestCase):

    def test_fconv_self_att_wp(self):
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory('test_fconv_self_att_wp') as data_dir:
                create_dummy_data(data_dir)
                preprocess_translation_data(data_dir)
                config = [
                    '--encoder-layers', '[(512, 3)] * 2',
                    '--decoder-layers', '[(512, 3)] * 2',
                    '--decoder-attention', 'True',
                    '--encoder-attention', 'False',
                    '--gated-attention', 'True',
                    '--self-attention', 'True',
                    '--project-input', 'True',
                ]
                train_translation_model(data_dir, 'fconv_self_att_wp', config)
                generate_main(data_dir)

                # fusion model
                os.rename(os.path.join(data_dir, 'checkpoint_last.pt'), os.path.join(data_dir, 'pretrained.pt'))
                config.extend([
                    '--pretrained', 'True',
                    '--pretrained-checkpoint', os.path.join(data_dir, 'pretrained.pt'),
                    '--save-dir', os.path.join(data_dir, 'fusion_model'),
                ])
                train_translation_model(data_dir, 'fconv_self_att_wp', config)
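
The stories test also covers the two-stage fusion setup: it first trains a base fconv_self_att_wp model, renames that run's checkpoint_last.pt to pretrained.pt, then trains a second model that loads it through --pretrained-checkpoint and writes its own checkpoints to a separate fusion_model directory, so the pretrained weights are never overwritten.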

class TestLanguageModeling(unittest.TestCase):

    def test_fconv_lm(self):
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory('test_fconv_lm') as data_dir:
                create_dummy_data(data_dir)
                preprocess_lm_data(data_dir)
                train_language_model(data_dir, 'fconv_lm')
                eval_lm_main(data_dir)

def create_dummy_data(data_dir, num_examples=1000, maxlen=20):

    def _create_dummy_data(filename):
        data = torch.rand(num_examples * maxlen)
        data = 97 + torch.floor(26 * data).int()
        with open(os.path.join(data_dir, filename), 'w') as h:
            offset = 0
            for _ in range(num_examples):
                ex_len = random.randint(1, maxlen)
                ex_str = ' '.join(map(chr, data[offset:offset+ex_len]))
                print(ex_str, file=h)
                offset += ex_len

    _create_dummy_data('train.in')
    _create_dummy_data('train.out')
    _create_dummy_data('valid.in')
    _create_dummy_data('valid.out')
    _create_dummy_data('test.in')
    _create_dummy_data('test.out')
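
create_dummy_data draws integers in [97, 122], the ASCII codes for 'a' through 'z', so every line of the synthetic corpus is a short space-separated sequence of lowercase letters. A quick standalone sketch of one such line (not part of the test file):

import torch

data = 97 + torch.floor(26 * torch.rand(8)).int()
print(' '.join(map(chr, data)))  # e.g. "q f w a d k b p"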

def preprocess_translation_data(data_dir):
    preprocess_parser = preprocess.get_parser()
    preprocess_args = preprocess_parser.parse_args([
        '--source-lang', 'in',
        '--target-lang', 'out',
        '--trainpref', os.path.join(data_dir, 'train'),
        '--validpref', os.path.join(data_dir, 'valid'),
        '--testpref', os.path.join(data_dir, 'test'),
        '--thresholdtgt', '0',
        '--thresholdsrc', '0',
        '--destdir', data_dir,
    ])
    preprocess.main(preprocess_args)

def train_translation_model(data_dir, arch, extra_flags=None):
    train_parser = options.get_training_parser()
    train_args = options.parse_args_and_arch(
        train_parser,
        [
            data_dir,
            '--save-dir', data_dir,
            '--arch', arch,
            '--optimizer', 'nag',
            '--lr', '0.05',
            '--max-tokens', '500',
            '--max-epoch', '1',
            '--no-progress-bar',
            '--distributed-world-size', '1',
            '--source-lang', 'in',
            '--target-lang', 'out',
        ] + (extra_flags or []),
    )
    train.main(train_args)

def generate_main(data_dir):
    generate_parser = options.get_generation_parser()
    generate_args = generate_parser.parse_args([
        data_dir,
        '--path', os.path.join(data_dir, 'checkpoint_last.pt'),
        '--beam', '3',
        '--batch-size', '64',
        '--max-len-b', '5',
        '--gen-subset', 'valid',
        '--no-progress-bar',
    ])

    # evaluate model in batch mode
    generate.main(generate_args)

    # evaluate model interactively
    generate_args.buffer_size = 0
    generate_args.max_sentences = None
    orig_stdin = sys.stdin
    sys.stdin = StringIO('h e l l o\n')
    interactive.main(generate_args)
    sys.stdin = orig_stdin

def preprocess_lm_data(data_dir):
    preprocess_parser = preprocess.get_parser()
    preprocess_args = preprocess_parser.parse_args([
        '--only-source',
        '--trainpref', os.path.join(data_dir, 'train.out'),
        '--validpref', os.path.join(data_dir, 'valid.out'),
        '--testpref', os.path.join(data_dir, 'test.out'),
        '--destdir', data_dir,
    ])
    preprocess.main(preprocess_args)

def train_language_model(data_dir, arch):
    train_parser = options.get_training_parser()
    train_args = options.parse_args_and_arch(
        train_parser,
        [
            data_dir,
            '--arch', arch,
            '--optimizer', 'nag',
            '--lr', '1.0',
            '--criterion', 'adaptive_loss',
            '--adaptive-softmax-cutoff', '5,10,15',
            '--decoder-layers', '[(850, 3)] * 2 + [(1024,4)]',
            '--decoder-embed-dim', '280',
            '--max-tokens', '500',
            '--max-target-positions', '500',
            '--save-dir', data_dir,
            '--max-epoch', '1',
            '--no-progress-bar',
            '--distributed-world-size', '1',
        ],
    )
    train.main(train_args)

def eval_lm_main(data_dir):
    eval_lm_parser = options.get_eval_lm_parser()
    eval_lm_args = eval_lm_parser.parse_args([
        data_dir,
        '--path', os.path.join(data_dir, 'checkpoint_last.pt'),
        '--no-progress-bar',
    ])
    eval_lm.main(eval_lm_args)

if __name__ == '__main__':
    unittest.main()
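
With the unittest.main() entry point, the whole suite can be run directly via python tests/test_binaries.py (assuming that path) or through a unittest runner. Each test builds and trains its model inside a throwaway tempfile.TemporaryDirectory, so no state leaks between cases.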