Commit 8798a240 authored by Liezl Puzon, committed by Facebook Github Bot

Have noising account for sentences with and without EOS (#305)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/305

Previously, the noising code assumed that every sentence ended in an EOS token, which had to be excluded from noising operations (since we shouldn't drop, blank, or shuffle EOS). This change allows the noising module to handle sentences both with and without a trailing EOS.
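
To illustrate the idea, here is a minimal standalone sketch of the new keep-mask logic (the helper function and its arguments are hypothetical; the real implementation in WordDropout.noising operates on word_idx groupings over a (T x B) batch, as in the diff below):

    import numpy as np

    def keep_mask(num_words, ends_in_eos, dropout_prob=0.2):
        # Draw one keep decision per word; if the sentence ends in EOS,
        # draw one fewer and force-keep the final EOS symbol.
        if ends_in_eos:
            keep = np.random.rand(num_words - 1) >= dropout_prob
            keep = np.append(keep, [True])  # always keep EOS
        else:
            keep = np.random.rand(num_words) >= dropout_prob
        return keep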

Reviewed By: xianxl

Differential Revision: D10114425

fbshipit-source-id: 04ec8547343eb94266bda1ac7fca3d8a1991c9f4
parent 265f42b7
--- a/fairseq/data/noising.py
+++ b/fairseq/data/noising.py
@@ -22,6 +22,12 @@ class WordNoising(object):
         raise NotImplementedError()
 
     def _get_bpe_word_idx(self, x):
+        """
+        Given a list of BPE tokens, for every index in the tokens list,
+        return the index of the word grouping that it belongs to.
+        For example, for input x corresponding to ["how", "are", "y@@", "ou"],
+        return [0, 1, 2, 2].
+        """
         # x: (T x B)
         bpe_end = self.bpe_end[x]
         # do a reduce front sum to generate word ids
@@ -53,9 +59,23 @@ class WordDropout(WordNoising):
             # Since dropout probabilities need to apply over non-pad tokens,
             # it is not trivial to generate the keep mask without consider
             # input lengths; otherwise, this could be done outside the loop
-            keep = np.random.rand(lengths[i] - 1) >= dropout_prob
+
+            # We want to drop whole words based on word_idx grouping
+            num_words = max(word_idx[:, i]) + 1
 
             # ith example: [x0, x1, ..., eos, pad, ..., pad]
-            assert x[lengths[i] - 1, i] == self.dictionary.eos()
+            # We should only generate keep probs for non-EOS tokens. Thus if the
+            # input sentence ends in EOS, the last word idx is not included in
+            # the dropout mask generation and we append True to always keep EOS.
+            # Otherwise, just generate the dropout mask for all word idx
+            # positions.
+            has_eos = x[lengths[i] - 1, i] == self.dictionary.eos()
+            if has_eos:  # has eos?
+                keep = np.random.rand(num_words - 1) >= dropout_prob
+                keep = np.append(keep, [True])  # keep EOS symbol
+            else:
+                keep = np.random.rand(num_words) >= dropout_prob
 
             words = x[:lengths[i], i].tolist()
             # TODO: speed up the following loop
@@ -67,11 +87,13 @@ class WordDropout(WordNoising):
             new_s = [w for w in new_s if w is not None]
             # we need to have at least one word in the sentence (more than the
             # start / end sentence symbols)
-            if len(new_s) == 1:
-                new_s.append(words[np.random.randint(0, len(words))])
-            assert (
-                len(new_s) >= 2
-                and new_s[-1] == self.dictionary.eos()
+            if len(new_s) <= 1:
+                # insert at beginning in case the only token left is EOS
+                # EOS should be at end of list.
+                new_s.insert(0, words[np.random.randint(0, len(words))])
+            assert len(new_s) >= 1 and (
+                not has_eos  # Either don't have EOS at end or last token is EOS
+                or (len(new_s) >= 2 and new_s[-1] == self.dictionary.eos())
             ), "New sentence is invalid."
             sentences.append(new_s)
             modified_lengths.append(len(new_s))
@@ -114,13 +136,17 @@ class WordShuffle(WordNoising):
         x2 = x.clone()
         for i in range(lengths.size(0)):
+            length_no_eos = lengths[i]
+            if x[lengths[i] - 1, i] == self.dictionary.eos():
+                length_no_eos = lengths[i] - 1
             # generate a random permutation
-            scores = word_idx[:lengths[i] - 1, i] + noise[word_idx[:lengths[i] - 1, i], i]
+            scores = word_idx[:length_no_eos, i] + noise[word_idx[:length_no_eos, i], i]
             # ensure no reordering inside a word
-            scores += 1e-6 * np.arange(lengths[i] - 1)
+            scores += 1e-6 * np.arange(length_no_eos)
             permutation = scores.argsort()
             # shuffle words
-            x2[:lengths[i] - 1, i].copy_(
-                x2[:lengths[i] - 1, i][torch.from_numpy(permutation)]
+            x2[:length_no_eos, i].copy_(
+                x2[:length_no_eos, i][torch.from_numpy(permutation)]
             )
 
         return x2, lengths
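
The shuffle itself is unchanged: it argsorts noisy per-word scores, and only the span it permutes now excludes a trailing EOS. For intuition, a minimal numpy sketch of that scoring trick (toy data; word_idx is the grouping produced by _get_bpe_word_idx, e.g. ["how", "are", "y@@", "ou"] -> [0, 1, 2, 2]):

    import numpy as np

    word_idx = np.array([0, 1, 2, 2])  # 4 BPE tokens grouped into 3 words
    max_shuffle_distance = 3

    # One noise draw per word; indexing the draws by word_idx gives every
    # BPE token the same noise value as the word it belongs to.
    noise = np.random.uniform(0, max_shuffle_distance, size=word_idx.max() + 1)
    scores = word_idx + noise[word_idx]
    # Tiny position-based tiebreaker: tokens of the same word get strictly
    # increasing scores, so there is no reordering inside a word.
    scores += 1e-6 * np.arange(len(word_idx))
    permutation = scores.argsort()  # e.g. array([0, 2, 3, 1])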
--- a/tests/test_noising.py
+++ b/tests/test_noising.py
@@ -8,11 +8,11 @@
 import torch
 import unittest
 
-from fairseq.data import data_utils, Dictionary, noising
+from fairseq.data import Dictionary, data_utils, noising
 
 
 class TestDataNoising(unittest.TestCase):
-    def _get_test_data(self):
+    def _get_test_data(self, append_eos=True):
         vocab = Dictionary()
         vocab.add_symbol("he@@")
         vocab.add_symbol("llo")
@@ -30,64 +30,163 @@ class TestDataNoising(unittest.TestCase):
             ["how", "are", "y@@", "ou"],
         ]
         src_len = [len(x) for x in src_tokens]
-        x = torch.LongTensor(len(src_tokens), max(src_len) + 1).fill_(vocab.pad())
+        # If we have to append EOS, we include EOS in counting src length
+        if append_eos:
+            src_len = [length + 1 for length in src_len]
+        x = torch.LongTensor(len(src_tokens), max(src_len)).fill_(vocab.pad())
         for i in range(len(src_tokens)):
             for j in range(len(src_tokens[i])):
                 x[i][j] = vocab.index(src_tokens[i][j])
-            x[i][j + 1] = vocab.eos()
+            if append_eos:
+                x[i][j + 1] = vocab.eos()
         x = x.transpose(1, 0)
-        return vocab, x, torch.LongTensor([i + 1 for i in src_len])
+        return vocab, x, torch.LongTensor(src_len)
+
+    def assert_eos_at_end(self, x, x_len, eos):
+        """ Asserts last token of every sentence in x is EOS """
+        for i in range(len(x_len)):
+            self.assertEqual(
+                x[x_len[i] - 1][i],
+                eos,
+                f"Expected eos (token id {eos}) at the end of sentence {i} but "
+                f"got {x[i][-1]} instead",
+            )
+
+    def assert_word_dropout_correct(self, x, x_noised, x_len, l_noised):
+        # Expect only the first word (2 bpe tokens) of the first example
+        # was dropped out
+        self.assertEqual(x_len[0] - 2, l_noised[0])
+        for i in range(l_noised[0]):
+            self.assertEqual(x_noised[i][0], x[i + 2][0])
 
-    def test_word_dropout(self):
-        vocab, x, x_len = self._get_test_data()
+    def test_word_dropout_with_eos(self):
+        vocab, x, x_len = self._get_test_data(append_eos=True)
         with data_utils.numpy_seed(1234):
             noising_gen = noising.WordDropout(vocab)
             x_noised, l_noised = noising_gen.noising(x, x_len, 0.2)
-        # Expect only the first word (2 bpe tokens) of the first example
-        # was dropped out
-        self.assertEqual(x_len[0] - 2, l_noised[0])
-        for i in range(l_noised[0]):
-            self.assertEqual(x_noised[i][0], x[i + 2][0])
+        self.assert_word_dropout_correct(
+            x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
+        )
+        self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
+
+    def assert_word_blanking_correct(self, x, x_noised, x_len, l_noised, unk):
+        # Expect only the first word (2 bpe tokens) of the first example
+        # was blanked out
+        self.assertEqual(x_len[0], l_noised[0])
+        for i in range(l_noised[0]):
+            if i < 2:
+                self.assertEqual(x_noised[i][0], unk)
+            else:
+                self.assertEqual(x_noised[i][0], x[i][0])
 
-    def test_word_blank(self):
-        vocab, x, x_len = self._get_test_data()
+    def test_word_blank_with_eos(self):
+        vocab, x, x_len = self._get_test_data(append_eos=True)
         with data_utils.numpy_seed(1234):
             noising_gen = noising.WordDropout(vocab)
             x_noised, l_noised = noising_gen.noising(x, x_len, 0.2, vocab.unk())
-        # Expect only the first word (2 bpe tokens) of the first example
-        # was blanked out
-        self.assertEqual(x_len[0], l_noised[0])
-        for i in range(l_noised[0]):
-            if i < 2:
-                self.assertEqual(x_noised[i][0], vocab.unk())
-            else:
-                self.assertEqual(x_noised[i][0], x[i][0])
+        self.assert_word_blanking_correct(
+            x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised, unk=vocab.unk()
+        )
+        self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
+
+    def assert_no_shuffle_with_0_distance(self, x, x_noised, x_len, l_noised):
+        """
+        Applies word shuffle with 0 max_shuffle_distance and asserts that no
+        shuffling happened
+        """
+        for i in range(len(x_len)):
+            for j in range(x_len[i]):
+                self.assertEqual(x[j][i], x_noised[j][i])
+        self.assertEqual(x_len[0], l_noised[0])
+
+    def assert_word_shuffle_with_distance_3(self, x, x_noised, x_len, l_noised):
+        """
+        Applies word shuffle with max_shuffle_distance = 3 and asserts that the
+        shuffling result is as expected. If test data changes, update this func
+        """
+        # Expect the second example has the last three tokens shuffled
+        # 6, 7, 8, 9 => 6, 8, 9, 7, where (8, 9) is a word
+        for i in range(x_len[0]):
+            self.assertEqual(x[i][0], x_noised[i][0])
+        shuffle_map = {0: 0, 1: 3, 2: 1, 3: 2}
+        for k, v in shuffle_map.items():
+            self.assertEqual(x[k][1], x_noised[v][1])
+        self.assertEqual(x_len[0], l_noised[0])
+        self.assertEqual(x_len[1], l_noised[1])
 
-    def test_word_shuffle(self):
-        vocab, x, x_len = self._get_test_data()
+    def test_word_shuffle_with_eos(self):
+        vocab, x, x_len = self._get_test_data(append_eos=True)
         with data_utils.numpy_seed(1234):
             word_shuffle = noising.WordShuffle(vocab)
             x_noised, l_noised = word_shuffle.noising(x, x_len, 0)
-            for i in range(len(x_len)):
-                for j in range(x_len[i]):
-                    self.assertEqual(x[j][i], x_noised[j][i])
-            self.assertEqual(x_len[0], l_noised[0])
+            self.assert_no_shuffle_with_0_distance(
+                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
+            )
+            self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
             x_noised, l_noised = word_shuffle.noising(x, x_len, 3)
-            # Expect the second example has the last three tokens shuffled
-            # 6, 7, 8, 9 => 6, 8, 9, 7, where (8, 9) is a word
-            for i in range(x_len[0]):
-                self.assertEqual(x[i][0], x_noised[i][0])
-            shuffle_map = {0: 0, 1: 3, 2: 1, 3: 2}
-            for k, v in shuffle_map.items():
-                self.assertEqual(x[k][1], x_noised[v][1])
-            self.assertEqual(x_len[0], l_noised[0])
-            self.assertEqual(x_len[1], l_noised[1])
+            self.assert_word_shuffle_with_distance_3(
+                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
+            )
+            self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
+
+    def assert_no_eos_at_end(self, x, x_len, eos):
+        """ Asserts that the last token of each sentence in x is not EOS """
+        for i in range(len(x_len)):
+            self.assertNotEqual(
+                x[x_len[i] - 1][i],
+                eos,
+                f"Expected no eos (token id {eos}) at the end of sentence {i}.",
+            )
+
+    def test_word_dropout_without_eos(self):
+        """ Same result as word dropout with eos except no EOS at end """
+        vocab, x, x_len = self._get_test_data(append_eos=False)
+        with data_utils.numpy_seed(1234):
+            noising_gen = noising.WordDropout(vocab)
+            x_noised, l_noised = noising_gen.noising(x, x_len, 0.2)
+        self.assert_word_dropout_correct(
+            x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
+        )
+        self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
+
+    def test_word_blank_without_eos(self):
+        """ Same result as word blank with eos except no EOS at end """
+        vocab, x, x_len = self._get_test_data(append_eos=False)
+        with data_utils.numpy_seed(1234):
+            noising_gen = noising.WordDropout(vocab)
+            x_noised, l_noised = noising_gen.noising(x, x_len, 0.2, vocab.unk())
+        self.assert_word_blanking_correct(
+            x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised, unk=vocab.unk()
+        )
+        self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
+
+    def test_word_shuffle_without_eos(self):
+        """ Same result as word shuffle with eos except no EOS at end """
+        vocab, x, x_len = self._get_test_data(append_eos=False)
+        with data_utils.numpy_seed(1234):
+            word_shuffle = noising.WordShuffle(vocab)
+            x_noised, l_noised = word_shuffle.noising(x, x_len, 0)
+            self.assert_no_shuffle_with_0_distance(
+                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
+            )
+            self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
+            x_noised, l_noised = word_shuffle.noising(x, x_len, 3)
+            self.assert_word_shuffle_with_distance_3(
+                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
+            )
+            self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
 
 
 if __name__ == '__main__':
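
For reference, a sketch of how the noising APIs exercised by these tests are called (mirroring the test setup above; vocab, x, and x_len are the Dictionary, the (T x B) LongTensor of token ids, and the per-sentence lengths built by _get_test_data):

    from fairseq.data import data_utils, noising

    with data_utils.numpy_seed(1234):
        dropout = noising.WordDropout(vocab)
        x_dropped, len_dropped = dropout.noising(x, x_len, 0.2)               # drop whole words
        x_blanked, len_blanked = dropout.noising(x, x_len, 0.2, vocab.unk())  # blank words with unk
        shuffle = noising.WordShuffle(vocab)
        x_shuffled, len_shuffled = shuffle.noising(x, x_len, 3)               # shuffle, max distance 3

After this change, all three calls work whether or not the input sentences end in EOS.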