Commit 86857a58 authored by Changhan Wang, committed by Facebook GitHub Bot

Levenshtein Transformer paper code

Summary:
Code for our NeurIPS paper [Levenshtein Transformer](https://arxiv.org/abs/1905.11006)
* Added the Levenshtein Transformer model, task, and criterion classes
* Added the iterative NAT Transformer, Insertion Transformer, and CMLM Transformer model classes as baselines
* Added an option for prepending BOS to the dictionary and translation task classes
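
For reference, a minimal training invocation for the new task might look like the sketch below. The task/arch/criterion flags mirror the new `test_levenshtein_transformer` unit test; the data path and `--max-tokens` budget are placeholders, not part of this commit.

```python
import sys
from fairseq_cli import train

# Hypothetical invocation; 'data-bin/wmt14_en_de' and --max-tokens are
# placeholders. The remaining flags come from the new unit test.
sys.argv = [
    'fairseq-train', 'data-bin/wmt14_en_de',
    '--task', 'translation_lev',
    '--arch', 'levenshtein_transformer',
    '--criterion', 'nat_loss',
    '--apply-bert-init',
    '--early-exit', '6,6,6',
    '--max-tokens', '4000',
]
train.cli_main()
```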

Reviewed By: myleott

Differential Revision: D17297372

fbshipit-source-id: 54eca60831ae95dc721c2c34e882e1810ee575c7
parent 6c1da0f7
@@ -359,3 +359,11 @@ def has_parameters(module):
return True
except StopIteration:
return False
def set_torch_seed(seed):
    # Seed both the CPU and the current CUDA RNG. Callers typically derive
    # `seed` from args.seed plus the update number so that results are
    # reproducible when resuming from checkpoints.
assert isinstance(seed, int)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
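
A tiny sanity check of the helper (illustrative, not part of the diff): seeding immediately before a random draw makes the draw repeatable. Note that `torch.cuda.manual_seed` seeds only the current CUDA device; full multi-GPU reproducibility would need `torch.cuda.manual_seed_all`.

```python
import torch
from fairseq import utils

utils.set_torch_seed(7)
a = torch.randn(2, 3)
utils.set_torch_seed(7)
b = torch.randn(2, 3)
assert torch.equal(a, b)  # identical draws under the same seed
```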
@@ -159,6 +159,9 @@ def main(args):
' '.join(map(lambda x: str(utils.item(x)), alignment))
))
if args.print_step:
print('I-{}\t{}'.format(sample_id, hypo['steps']))
# Score only the top hypothesis
if has_target and j == 0:
if align_dict is not None or args.remove_bpe is not None:
......
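
With `--print-step`, each hypothesis is followed by an `I-<sample_id>` record whose second (tab-separated) field is `hypo['steps']`; for the refinement-based models in this commit, that is the number of decoding iterations taken. A hypothetical line for sample 3 that converged in 2 iterations would read:

```
I-3	2
```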
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.
from setuptools import setup, find_packages, Extension
from torch.utils import cpp_extension
import sys
@@ -60,6 +61,12 @@ extensions = [
language='c++',
extra_compile_args=extra_compile_args,
),
cpp_extension.CppExtension(
'fairseq.libnat',
sources=[
'fairseq/clib/libnat/edit_dist.cpp',
],
)
]
@@ -106,5 +113,6 @@ setup(
'fairseq-validate = fairseq_cli.validate:cli_main',
],
},
cmdclass={'build_ext': cpp_extension.BuildExtension},
zip_safe=False,
)
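
With the `CppExtension` entry and the `BuildExtension` cmdclass in place, a standard editable install (`pip install --editable .`) or `python setup.py build_ext --inplace` should compile `fairseq/clib/libnat/edit_dist.cpp`. A quick import check (a sketch, not part of this commit):

```python
# Verify the C++ extension was built and is importable.
from fairseq import libnat

print(libnat.__file__)  # path to the compiled shared object
```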
@@ -180,6 +180,52 @@ class TestTranslation(unittest.TestCase):
])
generate_main(data_dir)
def test_levenshtein_transformer(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_levenshtein_transformer') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
train_translation_model(data_dir, 'levenshtein_transformer', [
'--apply-bert-init', '--early-exit', '6,6,6',
'--criterion', 'nat_loss'
], task='translation_lev')
generate_main(data_dir)
def test_nonautoregressive_transformer(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_nonautoregressive_transformer') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
train_translation_model(data_dir, 'nonautoregressive_transformer', [
'--apply-bert-init', '--src-embedding-copy', '--criterion',
'nat_loss', '--noise', 'full_mask', '--pred-length-offset',
'--length-loss-factor', '0.1'
], task='translation_lev')
generate_main(data_dir)
def test_iterative_nonautoregressive_transformer(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_iterative_nonautoregressive_transformer') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
train_translation_model(data_dir, 'iterative_nonautoregressive_transformer', [
'--apply-bert-init', '--src-embedding-copy', '--criterion',
'nat_loss', '--noise', 'full_mask', '--stochastic-approx',
'--dae-ratio', '0.5', '--train-step', '3'
], task='translation_lev')
generate_main(data_dir)
def test_insertion_transformer(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_insertion_transformer') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
train_translation_model(data_dir, 'insertion_transformer', [
'--apply-bert-init', '--criterion', 'nat_loss', '--noise',
'random_mask'
], task='translation_lev')
generate_main(data_dir)
def test_mixture_of_experts(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_moe') as data_dir:
......
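
The new tests follow the existing smoke-test pattern (create dummy data, preprocess, train briefly, then generate). One of them can be run in isolation roughly as follows (assumes a source checkout with `fairseq.libnat` built):

```python
import unittest
from tests.test_binaries import TestTranslation

# Run only the new Levenshtein Transformer smoke test.
suite = unittest.TestSuite([TestTranslation('test_levenshtein_transformer')])
unittest.TextTestRunner(verbosity=2).run(suite)
```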
@@ -194,6 +194,11 @@ def get_training_stats(trainer):
def validate(args, trainer, task, epoch_itr, subsets):
"""Evaluate the model on the validation set(s) and return the losses."""
if args.fixed_validation_seed is not None:
# set fixed seed for every validation
utils.set_torch_seed(args.fixed_validation_seed)
valid_losses = []
for subset in subsets:
# Initialize data iterator
......
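
This check presumably pairs with a `--fixed-validation-seed` command-line option (the flag name is inferred here from `args.fixed_validation_seed`). Re-seeding before every validation pass makes any stochastic component, such as the noise injected by the `translation_lev` task, behave identically across validations, so validation losses are comparable between epochs. A minimal sketch of the same pattern:

```python
import torch

def validate_deterministically(run_validation, fixed_seed=7):
    # Re-seed before each validation so stochastic ops repeat exactly.
    if fixed_seed is not None:
        torch.manual_seed(fixed_seed)
        torch.cuda.manual_seed(fixed_seed)
    return run_validation()
```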