Commit 8798a240 authored by Liezl Puzon's avatar Liezl Puzon Committed by Facebook Github Bot
Browse files

Have noising account for sentences with and without EOS (#305)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/305

Previously, the noising code assumed that every sentence had an EOS token which had to be excluded from noising operations (since we shouldn't drop, blank, or shuffle EOS). This change allows the noising module to handle sentences both with and without EOS.

Reviewed By: xianxl

Differential Revision: D10114425

fbshipit-source-id: 04ec8547343eb94266bda1ac7fca3d8a1991c9f4
parent 265f42b7
......@@ -22,6 +22,12 @@ class WordNoising(object):
raise NotImplementedError()
def _get_bpe_word_idx(self, x):
"""
Given a list of BPE tokens, for every index in the tokens list,
return the index of the word grouping that it belongs to.
For example, for input x corresponding to ["how", "are", "y@@", "ou"],
return [0, 1, 2, 2].
"""
# x: (T x B)
bpe_end = self.bpe_end[x]
# do a reduce front sum to generate word ids
......@@ -53,9 +59,23 @@ class WordDropout(WordNoising):
# Since dropout probabilities need to apply over non-pad tokens,
# it is not trivial to generate the keep mask without consider
# input lengths; otherwise, this could be done outside the loop
keep = np.random.rand(lengths[i] - 1) >= dropout_prob
# We want to drop whole words based on word_idx grouping
num_words = max(word_idx[:, i]) + 1
# ith example: [x0, x1, ..., eos, pad, ..., pad]
assert x[lengths[i] - 1, i] == self.dictionary.eos()
# We should only generate keep probs for non-EOS tokens. Thus if the
# input sentence ends in EOS, the last word idx is not included in
# the dropout mask generation and we append True to always keep EOS.
# Otherwise, just generate the dropout mask for all word idx
# positions.
has_eos = x[lengths[i] - 1, i] == self.dictionary.eos()
if has_eos: # has eos?
keep = np.random.rand(num_words - 1) >= dropout_prob
keep = np.append(keep, [True]) # keep EOS symbol
else:
keep = np.random.rand(num_words) >= dropout_prob
words = x[:lengths[i], i].tolist()
# TODO: speed up the following loop
......@@ -67,11 +87,13 @@ class WordDropout(WordNoising):
new_s = [w for w in new_s if w is not None]
# we need to have at least one word in the sentence (more than the
# start / end sentence symbols)
if len(new_s) == 1:
new_s.append(words[np.random.randint(0, len(words))])
assert (
len(new_s) >= 2
and new_s[-1] == self.dictionary.eos()
if len(new_s) <= 1:
# insert at beginning in case the only token left is EOS
# EOS should be at end of list.
new_s.insert(0, words[np.random.randint(0, len(words))])
assert len(new_s) >= 1 and (
not has_eos # Either don't have EOS at end or last token is EOS
or (len(new_s) >= 2 and new_s[-1] == self.dictionary.eos())
), "New sentence is invalid."
sentences.append(new_s)
modified_lengths.append(len(new_s))
......@@ -114,13 +136,17 @@ class WordShuffle(WordNoising):
x2 = x.clone()
for i in range(lengths.size(0)):
length_no_eos = lengths[i]
if x[lengths[i] - 1, i] == self.dictionary.eos():
length_no_eos = lengths[i] - 1
# generate a random permutation
scores = word_idx[:lengths[i] - 1, i] + noise[word_idx[:lengths[i] - 1, i], i]
scores = word_idx[:length_no_eos, i] + noise[word_idx[:length_no_eos, i], i]
# ensure no reordering inside a word
scores += 1e-6 * np.arange(lengths[i] - 1)
scores += 1e-6 * np.arange(length_no_eos)
permutation = scores.argsort()
# shuffle words
x2[:lengths[i] - 1, i].copy_(
x2[:lengths[i] - 1, i][torch.from_numpy(permutation)]
x2[:length_no_eos, i].copy_(
x2[:length_no_eos, i][torch.from_numpy(permutation)]
)
return x2, lengths
......@@ -8,11 +8,11 @@
import torch
import unittest
from fairseq.data import data_utils, Dictionary, noising
from fairseq.data import Dictionary, data_utils, noising
class TestDataNoising(unittest.TestCase):
def _get_test_data(self):
def _get_test_data(self, append_eos=True):
vocab = Dictionary()
vocab.add_symbol("he@@")
vocab.add_symbol("llo")
......@@ -30,55 +30,84 @@ class TestDataNoising(unittest.TestCase):
["how", "are", "y@@", "ou"],
]
src_len = [len(x) for x in src_tokens]
x = torch.LongTensor(len(src_tokens), max(src_len) + 1).fill_(vocab.pad())
# If we have to append EOS, we include EOS in counting src length
if append_eos:
src_len = [length + 1 for length in src_len]
x = torch.LongTensor(len(src_tokens), max(src_len)).fill_(vocab.pad())
for i in range(len(src_tokens)):
for j in range(len(src_tokens[i])):
x[i][j] = vocab.index(src_tokens[i][j])
if append_eos:
x[i][j + 1] = vocab.eos()
x = x.transpose(1, 0)
return vocab, x, torch.LongTensor([i + 1 for i in src_len])
def test_word_dropout(self):
vocab, x, x_len = self._get_test_data()
return vocab, x, torch.LongTensor(src_len)
with data_utils.numpy_seed(1234):
noising_gen = noising.WordDropout(vocab)
x_noised, l_noised = noising_gen.noising(x, x_len, 0.2)
def assert_eos_at_end(self, x, x_len, eos):
    """ Asserts last token of every sentence in x is EOS.

    Args:
        x: (T x B) tensor of token ids, one sentence per column.
        x_len: length (including EOS) of each of the B sentences.
        eos: token id expected at the end of every sentence.
    """
    for i in range(len(x_len)):
        self.assertEqual(
            x[x_len[i] - 1][i],
            eos,
            # Report the token actually inspected: the old message printed
            # x[i][-1], which indexes a different element than the assertion.
            f"Expected eos (token id {eos}) at the end of sentence {i} but "
            f"got {x[x_len[i] - 1][i]} instead"
        )
def assert_word_dropout_correct(self, x, x_noised, x_len, l_noised):
    """Checks the word-dropout result produced under the fixed test seed.

    Expect only the first word (2 bpe tokens) of the first example
    was dropped out, so the noised sentence is the original shifted
    left by two tokens.
    """
    self.assertEqual(x_len[0] - 2, l_noised[0])
    for i in range(l_noised[0]):
        # Single assertion here: the diff residue had this line twice
        # (old and reformatted variants), which was redundant.
        self.assertEqual(x_noised[i][0], x[i + 2][0])
def test_word_blank(self):
vocab, x, x_len = self._get_test_data()
def test_word_dropout_with_eos(self):
    """Word dropout on EOS-terminated input keeps EOS at sentence ends."""
    vocab, x, x_len = self._get_test_data(append_eos=True)

    with data_utils.numpy_seed(1234):
        noising_gen = noising.WordDropout(vocab)
        # Removed a stray blanking call (noising(..., 0.2, vocab.unk()),
        # residue of the old test_word_blank) that ran before this one and
        # perturbed the seeded random state the assertions depend on.
        x_noised, l_noised = noising_gen.noising(x, x_len, 0.2)
        self.assert_word_dropout_correct(
            x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
        )
        self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
def assert_word_blanking_correct(self, x, x_noised, x_len, l_noised, unk):
    """Checks the word-blanking result produced under the fixed test seed.

    Expect only the first word (2 bpe tokens) of the first example
    was blanked out; blanking never changes sentence lengths.
    """
    self.assertEqual(x_len[0], l_noised[0])
    for i in range(l_noised[0]):
        if i < 2:
            # Compare against the unk id passed in as a parameter; the
            # stale residue line compared against vocab.unk(), but vocab
            # is not in scope here and raised NameError.
            self.assertEqual(x_noised[i][0], unk)
        else:
            self.assertEqual(x_noised[i][0], x[i][0])
def test_word_shuffle(self):
vocab, x, x_len = self._get_test_data()
def test_word_blank_with_eos(self):
    """Word blanking on EOS-terminated input keeps EOS at sentence ends."""
    vocab, x, x_len = self._get_test_data(append_eos=True)

    with data_utils.numpy_seed(1234):
        # Removed residue of the old test_word_shuffle (a WordShuffle
        # construction and noising call) that consumed seeded random
        # state before the blanking call, breaking determinism.
        noising_gen = noising.WordDropout(vocab)
        x_noised, l_noised = noising_gen.noising(x, x_len, 0.2, vocab.unk())
        self.assert_word_blanking_correct(
            x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised, unk=vocab.unk()
        )
        self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
def assert_no_shuffle_with_0_distance(self, x, x_noised, x_len, l_noised):
    """
    Applies word shuffle with 0 max_shuffle_distance and asserts that no
    shuffling happened
    """
    for i in range(len(x_len)):
        for j in range(x_len[i]):
            # Every token must be exactly where it started.
            self.assertEqual(x[j][i], x_noised[j][i])
    # Lengths must also be untouched.
    # NOTE(review): removed a trailing residue line that called
    # word_shuffle.noising(x, x_len, 3) — word_shuffle is undefined in
    # this helper and raised NameError.
    self.assertEqual(x_len[0], l_noised[0])
def assert_word_shuffle_with_distance_3(self, x, x_noised, x_len, l_noised):
"""
Applies word shuffle with max_shuffle_distance = 3 and asserts that the
shuffling result is as expected. If test data changes, update this func
"""
# Expect the second example has the last three tokens shuffled
# 6, 7, 8, 9 => 6, 8, 9, 7, where (8, 9) is a word
for i in range(x_len[0]):
......@@ -89,6 +118,76 @@ class TestDataNoising(unittest.TestCase):
self.assertEqual(x_len[0], l_noised[0])
self.assertEqual(x_len[1], l_noised[1])
def test_word_shuffle_with_eos(self):
    """Word shuffle on EOS-terminated input must keep EOS at sentence ends."""
    vocab, x, x_len = self._get_test_data(append_eos=True)

    with data_utils.numpy_seed(1234):
        shuffler = noising.WordShuffle(vocab)

        # A max shuffle distance of 0 must be a no-op.
        x_noised, l_noised = shuffler.noising(x, x_len, 0)
        self.assert_no_shuffle_with_0_distance(
            x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
        )
        self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

        # Distance 3 yields the fixed permutation checked by the helper.
        x_noised, l_noised = shuffler.noising(x, x_len, 3)
        self.assert_word_shuffle_with_distance_3(
            x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
        )
        self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
def assert_no_eos_at_end(self, x, x_len, eos):
    """ Asserts that the last token of each sentence in x is not EOS """
    for i, length in enumerate(x_len):
        last_token = x[length - 1][i]
        self.assertNotEqual(
            last_token,
            eos,
            f"Expected no eos (token id {eos}) at the end of sentence {i}."
        )
def test_word_dropout_without_eos(self):
    """Word dropout on EOS-free input must not introduce an EOS token."""
    vocab, x, x_len = self._get_test_data(append_eos=False)

    with data_utils.numpy_seed(1234):
        dropout = noising.WordDropout(vocab)
        noised_x, noised_len = dropout.noising(x, x_len, 0.2)
        # Same expected dropout pattern as the with-EOS variant ...
        self.assert_word_dropout_correct(
            x=x, x_noised=noised_x, x_len=x_len, l_noised=noised_len
        )
        # ... except that no EOS may appear at any sentence end.
        self.assert_no_eos_at_end(x=noised_x, x_len=noised_len, eos=vocab.eos())
def test_word_blank_without_eos(self):
    """Word blanking on EOS-free input must not introduce an EOS token."""
    vocab, x, x_len = self._get_test_data(append_eos=False)

    with data_utils.numpy_seed(1234):
        blanker = noising.WordDropout(vocab)
        noised_x, noised_len = blanker.noising(x, x_len, 0.2, vocab.unk())
        # Same expected blanking pattern as the with-EOS variant ...
        self.assert_word_blanking_correct(
            x=x, x_noised=noised_x, x_len=x_len, l_noised=noised_len, unk=vocab.unk()
        )
        # ... except that no EOS may appear at any sentence end.
        self.assert_no_eos_at_end(x=noised_x, x_len=noised_len, eos=vocab.eos())
def test_word_shuffle_without_eos(self):
    """Word shuffle on EOS-free input must not introduce an EOS token."""
    vocab, x, x_len = self._get_test_data(append_eos=False)

    with data_utils.numpy_seed(1234):
        shuffler = noising.WordShuffle(vocab)

        # A max shuffle distance of 0 must be a no-op.
        noised_x, noised_len = shuffler.noising(x, x_len, 0)
        self.assert_no_shuffle_with_0_distance(
            x=x, x_noised=noised_x, x_len=x_len, l_noised=noised_len
        )
        self.assert_no_eos_at_end(x=noised_x, x_len=noised_len, eos=vocab.eos())

        # Distance 3 yields the fixed permutation checked by the helper.
        noised_x, noised_len = shuffler.noising(x, x_len, 3)
        self.assert_word_shuffle_with_distance_3(
            x=x, x_noised=noised_x, x_len=x_len, l_noised=noised_len
        )
        self.assert_no_eos_at_end(x=noised_x, x_len=noised_len, eos=vocab.eos())
# Run the noising test suite when this file is executed directly.
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment