"vscode:/vscode.git/clone" did not exist on "31d53367671d2deed42318a488a8d66158d7fbbe"
Commit e286243c authored by Liezl Puzon, committed by Facebook Github Bot

Add denoising dataset for denoising autoencoder (#306)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/306

This uses a source dataset to generate a batch of {source: noisy source, target: original clean source} which allows us to train a denoising autoencoding component as part of a seq2seq model.
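For context, a sketch of how the pieces introduced here fit together (the parameter values and the clean_src_dataset/src_dict names are illustrative; the batch keys follow the LanguagePairDataset collater exercised in the new unit tests):

    from torch.utils.data import DataLoader
    from fairseq.data import LanguagePairDataset, noising

    noisy_src = noising.NoisingDataset(
        src_dataset=clean_src_dataset,   # unpadded clean source sentences
        src_dict=src_dict,
        seed=1234,
        max_word_shuffle_distance=3,
        word_dropout_prob=0.1,
        word_blanking_prob=0.1,
    )
    # Pair the noisy source with the original clean source as the target.
    pairs = LanguagePairDataset(
        src=noisy_src, tgt=clean_src_dataset, src_sizes=None, src_dict=src_dict
    )
    loader = DataLoader(pairs, batch_size=2, collate_fn=pairs.collater)
    batch = next(iter(loader))
    # batch["net_input"]["src_tokens"]: noisy source
    # batch["target"]: original clean source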

Reviewed By: xianxl

Differential Revision: D10078981

fbshipit-source-id: 026225984d4a97062ac05dc3a36e79b5c841fe9c
parent 8798a240
...@@ -9,6 +9,7 @@ from .dictionary import Dictionary, TruncatedDictionary
from .fairseq_dataset import FairseqDataset
from .concat_dataset import ConcatDataset
from .indexed_dataset import IndexedDataset, IndexedCachedDataset, IndexedInMemoryDataset, IndexedRawTextDataset
from .append_eos_dataset import AppendEosDataset
from .language_pair_dataset import LanguagePairDataset
from .monolingual_dataset import MonolingualDataset
from .token_block_dataset import TokenBlockDataset
...@@ -21,6 +22,7 @@ from .iterators import (
)

__all__ = [
    'AppendEosDataset',
    'ConcatDataset',
    'CountingIterator',
    'Dictionary',
......
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch


class AppendEosDataset(torch.utils.data.Dataset):
    """A dataset wrapper that appends EOS to each item."""

    def __init__(self, dataset, eos):
        self.dataset = dataset
        self.eos = eos

    def __getitem__(self, index):
        # Concatenate the EOS index onto the end of the wrapped item.
        return torch.cat([self.dataset[index], torch.LongTensor([self.eos])])

    def __len__(self):
        return len(self.dataset)
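
A quick usage sketch (the toy token ids below are illustrative, not from the commit): wrapping any map-style dataset of LongTensors so that each retrieved item ends with the dictionary's EOS index.

    import torch
    from fairseq.data import AppendEosDataset, Dictionary

    vocab = Dictionary()
    base = [torch.LongTensor([4, 5, 6]), torch.LongTensor([7, 8])]  # toy items
    wrapped = AppendEosDataset(base, vocab.eos())
    assert wrapped[0][-1].item() == vocab.eos()  # EOS appended to every item
    assert len(wrapped) == len(base)             # length is unchanged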
...@@ -31,7 +31,9 @@ class BacktranslationDataset(FairseqDataset):
    Args:
        tgt_dataset: dataset which will be used to build self.tgt_dataset --
            a LanguagePairDataset with tgt dataset as the source dataset and
            None as the target dataset. Should NOT have padding so that
            src_lengths are accurately calculated by language_pair_dataset
            collate function.
            We use language_pair_dataset here to encapsulate the tgt_dataset
            so we can re-use the LanguagePairDataset collater to format the
            batches in the structure that SequenceGenerator expects.
......
...@@ -8,6 +8,8 @@
import torch
import numpy as np

from fairseq.data import data_utils


class WordNoising(object):
    """Generate a noisy version of a sentence, without changing words themselves."""
...@@ -150,3 +152,116 @@ class WordShuffle(WordNoising):
                x2[:length_no_eos, i][torch.from_numpy(permutation)]
            )
        return x2, lengths


class UnsupervisedMTNoising(WordNoising):
    """
    Implements the default configuration for noising in UnsupervisedMT
    (github.com/facebookresearch/UnsupervisedMT)
    """

    def __init__(
        self,
        dictionary,
        max_word_shuffle_distance,
        word_dropout_prob,
        word_blanking_prob,
    ):
        super().__init__(dictionary)
        self.max_word_shuffle_distance = max_word_shuffle_distance
        self.word_dropout_prob = word_dropout_prob
        self.word_blanking_prob = word_blanking_prob

        self.word_dropout = WordDropout(dictionary=dictionary)
        self.word_shuffle = WordShuffle(dictionary=dictionary)

    def noising(self, x, lengths):
        # 1. Word Shuffle
        noisy_src_tokens, noisy_src_lengths = self.word_shuffle.noising(
            x=x,
            lengths=lengths,
            max_shuffle_distance=self.max_word_shuffle_distance,
        )
        # 2. Word Dropout
        noisy_src_tokens, noisy_src_lengths = self.word_dropout.noising(
            x=noisy_src_tokens,
            lengths=noisy_src_lengths,
            dropout_prob=self.word_dropout_prob,
        )
        # 3. Word Blanking
        noisy_src_tokens, noisy_src_lengths = self.word_dropout.noising(
            x=noisy_src_tokens,
            lengths=noisy_src_lengths,
            dropout_prob=self.word_blanking_prob,
            blank_idx=self.dictionary.unk(),
        )

        return noisy_src_tokens
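
A hedged usage sketch (token ids and probabilities are illustrative; "dictionary" stands in for a fairseq Dictionary): the noiser operates on a tensor shaped (sequence length, batch size) plus per-sentence lengths, and this subclass returns only the noised tokens rather than the (tokens, lengths) pair returned by the individual noisers.

    noiser = UnsupervisedMTNoising(
        dictionary=dictionary,
        max_word_shuffle_distance=3,
        word_dropout_prob=0.1,
        word_blanking_prob=0.1,
    )
    x = torch.LongTensor([[4], [5], [6], [dictionary.eos()]])  # (seq_len, bsz=1)
    lengths = torch.LongTensor([4])
    noisy_x = noiser.noising(x, lengths)  # shuffled, dropped, and blanked tokens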


class NoisingDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        src_dataset,
        src_dict,
        seed,
        noising_class=UnsupervisedMTNoising,
        **kwargs
    ):
        """
        Sets up a noising dataset which takes a src batch, generates a noisy
        src using a noising config, and returns the corresponding
        {noisy src, original src} batch.

        Args:
            src_dataset: dataset from which the clean source samples are drawn.
                Items should NOT have padding so that src_lengths are
                accurately calculated by the language_pair_dataset collate
                function when this dataset is wrapped in a LanguagePairDataset
                to re-use its collater for formatting batches.
            src_dict: src dictionary
            seed: seed to use when generating random noise
            noising_class: class to use when initializing the noiser
            kwargs: noising args for configuring the noising to apply

        Note that there is no equivalent argparse code for these args anywhere
        in our top level train scripts yet. Integration is still in progress.
        You can still, however, test out this dataset functionality with the
        appropriate args as in the corresponding unittest: test_noising_dataset.
        """
        self.src_dataset = src_dataset
        self.src_dict = src_dict
        self.noiser = noising_class(
            dictionary=src_dict, **kwargs
        )
        self.seed = seed

    def __getitem__(self, index):
        """
        Returns a single noisy sample. Multiple samples are fed to the collater
        to create a noising dataset batch.
        """
        src_tokens = self.src_dataset[index]
        src_lengths = torch.LongTensor([len(src_tokens)])
        src_tokens = src_tokens.unsqueeze(0)

        # Transpose src tokens to fit the expected shape of x in the noising
        # function: (batch size, sequence length) -> (sequence length, batch size)
        src_tokens_t = torch.t(src_tokens)

        with data_utils.numpy_seed(self.seed + index):
            noisy_src_tokens = self.noiser.noising(src_tokens_t, src_lengths)

        # Transpose back to the expected src_tokens format:
        # (sequence length, 1) -> (1, sequence length)
        noisy_src_tokens = torch.t(noisy_src_tokens)
        return noisy_src_tokens[0]

    def __len__(self):
        """
        The length of the noising dataset is the length of src.
        """
        return len(self.src_dataset)
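
A minimal per-item sketch (clean_src_dataset and src_dict are placeholders): because __getitem__ seeds numpy with seed + index, repeated lookups of the same index yield the same noisy variant of the clean sentence.

    dataset = NoisingDataset(
        src_dataset=clean_src_dataset,   # placeholder: unpadded clean LongTensors
        src_dict=src_dict,               # placeholder: fairseq Dictionary
        seed=1234,
        max_word_shuffle_distance=3,
        word_dropout_prob=0.1,
        word_blanking_prob=0.1,
    )
    noisy_0 = dataset[0]                     # noisy version of clean_src_dataset[0]
    assert torch.equal(noisy_0, dataset[0])  # deterministic given seed + index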
...@@ -8,7 +8,15 @@
import torch
import unittest

import tests.utils as test_utils
from fairseq import utils
from fairseq.data import (
    AppendEosDataset,
    Dictionary,
    data_utils,
    noising,
    LanguagePairDataset,
)


class TestDataNoising(unittest.TestCase):
...@@ -188,6 +196,119 @@ class TestDataNoising(unittest.TestCase):
        )
        self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

    def _get_noising_dataset_batch(
        self, src_tokens_no_pad, src_dict, use_append_eos_dataset=False
    ):
        """
        Constructs a NoisingDataset and the corresponding
        LanguagePairDataset(NoisingDataset(src), src). If use_append_eos_dataset
        is True, wrap the clean source dataset in AppendEosDataset before using
        it as the target. In practice, we should use AppendEosDataset because
        our models usually have source without EOS but target with EOS.
        """
        src_dataset = test_utils.TestDataset(data=src_tokens_no_pad)

        noising_dataset = noising.NoisingDataset(
            src_dataset=src_dataset,
            src_dict=src_dict,
            seed=1234,
            max_word_shuffle_distance=3,
            word_dropout_prob=0.2,
            word_blanking_prob=0.2,
            noising_class=noising.UnsupervisedMTNoising,
        )
        tgt = src_dataset
        if use_append_eos_dataset:
            tgt = AppendEosDataset(src_dataset, src_dict.eos())
        language_pair_dataset = LanguagePairDataset(
            src=noising_dataset, tgt=tgt, src_sizes=None, src_dict=src_dict
        )

        dataloader = torch.utils.data.DataLoader(
            dataset=language_pair_dataset,
            batch_size=2,
            collate_fn=language_pair_dataset.collater,
        )
        denoising_batch_result = next(iter(dataloader))
        return denoising_batch_result

    def test_noising_dataset_with_eos(self):
        src_dict, src_tokens, _ = self._get_test_data(append_eos=True)

        # Format data for src_dataset
        src_tokens = torch.t(src_tokens)
        src_tokens_no_pad = []
        for src_sentence in src_tokens:
            src_tokens_no_pad.append(
                utils.strip_pad(tensor=src_sentence, pad=src_dict.pad())
            )
        denoising_batch_result = self._get_noising_dataset_batch(
            src_tokens_no_pad=src_tokens_no_pad, src_dict=src_dict
        )

        eos, pad = src_dict.eos(), src_dict.pad()

        # Generated noisy source as source
        expected_src = torch.LongTensor(
            [[4, 5, 10, 11, 8, 12, 13, eos], [pad, pad, pad, 6, 8, 9, 7, eos]]
        )
        # Original clean source as target (right-padded)
        expected_tgt = torch.LongTensor(
            [[4, 5, 10, 11, 8, 12, 13, eos], [6, 7, 8, 9, eos, pad, pad, pad]]
        )
        generated_src = denoising_batch_result["net_input"]["src_tokens"]
        tgt_tokens = denoising_batch_result["target"]

        self.assertTensorEqual(expected_src, generated_src)
        self.assertTensorEqual(expected_tgt, tgt_tokens)

    def test_noising_dataset_without_eos(self):
        """
        Similar to test_noising_dataset_with_eos, except that we have to set
        use_append_eos_dataset=True so that we wrap the source dataset in
        AppendEosDataset when using it as the target in LanguagePairDataset.
        """
        src_dict, src_tokens, _ = self._get_test_data(append_eos=False)

        # Format data for src_dataset
        src_tokens = torch.t(src_tokens)
        src_tokens_no_pad = []
        for src_sentence in src_tokens:
            src_tokens_no_pad.append(
                utils.strip_pad(tensor=src_sentence, pad=src_dict.pad())
            )
        denoising_batch_result = self._get_noising_dataset_batch(
            src_tokens_no_pad=src_tokens_no_pad,
            src_dict=src_dict,
            use_append_eos_dataset=True,
        )

        eos, pad = src_dict.eos(), src_dict.pad()

        # Generated noisy source as source
        expected_src = torch.LongTensor(
            [[4, 5, 10, 11, 8, 12, 13], [pad, pad, pad, 6, 8, 9, 7]]
        )
        # Original clean source as target (right-padded)
        expected_tgt = torch.LongTensor(
            [[4, 5, 10, 11, 8, 12, 13, eos], [6, 7, 8, 9, eos, pad, pad, pad]]
        )
        generated_src = denoising_batch_result["net_input"]["src_tokens"]
        tgt_tokens = denoising_batch_result["target"]

        self.assertTensorEqual(expected_src, generated_src)
        self.assertTensorEqual(expected_tgt, tgt_tokens)

    def assertTensorEqual(self, t1, t2):
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertEqual(t1.ne(t2).long().sum(), 0)


if __name__ == '__main__':
    unittest.main()