Commit 90c01b3a authored by Xian Li's avatar Xian Li Committed by Facebook Github Bot
Browse files

Extend WordShuffle noising function to apply to non-bpe tokens

Summary:
We'd like to reuse the noising functions and DenoisingDataset in
adversarial training. However, the current noising functions assume the
inputs are subword tokens. The goal of this diff is to extend them so the
noising can be applied to word tokens. Since we're mostly interested in
word shuffle noising, I only modified the WordShuffle class.

Reviewed By: liezl200

Differential Revision: D10523177

fbshipit-source-id: 1e5d27362850675010e73cd38850c890d42652ab
parent 6117f827
...@@ -18,7 +18,11 @@ class WordNoising(object): ...@@ -18,7 +18,11 @@ class WordNoising(object):
self.bpe_end = np.array([ self.bpe_end = np.array([
not self.dictionary[i].endswith(bpe_cont_marker) not self.dictionary[i].endswith(bpe_cont_marker)
for i in range(len(self.dictionary)) for i in range(len(self.dictionary))
]) ]) if bpe_cont_marker else None
self.get_word_idx = (
self._get_bpe_word_idx if bpe_cont_marker else self._get_token_idx
)
def noising(self, x, lengths, noising_prob=0.0): def noising(self, x, lengths, noising_prob=0.0):
raise NotImplementedError() raise NotImplementedError()
...@@ -37,6 +41,15 @@ class WordNoising(object): ...@@ -37,6 +41,15 @@ class WordNoising(object):
word_idx = word_idx.max(0)[None, :] - word_idx word_idx = word_idx.max(0)[None, :] - word_idx
return word_idx return word_idx
def _get_token_idx(self, x):
"""
This is to extend noising functions to be able to apply to non-bpe
tokens, e.g. word or characters.
"""
x = torch.t(x)
word_idx = np.array([range(len(x_i)) for x_i in x])
return np.transpose(word_idx)
class WordDropout(WordNoising): class WordDropout(WordNoising):
"""Randomly drop input words. If not passing blank_idx (default is None), """Randomly drop input words. If not passing blank_idx (default is None),
...@@ -114,8 +127,8 @@ class WordDropout(WordNoising): ...@@ -114,8 +127,8 @@ class WordDropout(WordNoising):
class WordShuffle(WordNoising): class WordShuffle(WordNoising):
"""Shuffle words by no more than k positions.""" """Shuffle words by no more than k positions."""
def __init__(self, dictionary): def __init__(self, dictionary, bpe_cont_marker="@@"):
super().__init__(dictionary) super().__init__(dictionary, bpe_cont_marker)
def noising(self, x, lengths, max_shuffle_distance=3): def noising(self, x, lengths, max_shuffle_distance=3):
# x: (T x B), lengths: B # x: (T x B), lengths: B
...@@ -134,8 +147,7 @@ class WordShuffle(WordNoising): ...@@ -134,8 +147,7 @@ class WordShuffle(WordNoising):
noise[0] = -1 # do not move start sentence symbol noise[0] = -1 # do not move start sentence symbol
# be sure to shuffle entire words # be sure to shuffle entire words
word_idx = self._get_bpe_word_idx(x) word_idx = self.get_word_idx(x)
x2 = x.clone() x2 = x.clone()
for i in range(lengths.size(0)): for i in range(lengths.size(0)):
length_no_eos = lengths[i] length_no_eos = lengths[i]
......
...@@ -20,23 +20,36 @@ from fairseq.data import ( ...@@ -20,23 +20,36 @@ from fairseq.data import (
class TestDataNoising(unittest.TestCase): class TestDataNoising(unittest.TestCase):
def _get_test_data(self, append_eos=True): def _get_test_data(self, append_eos=True, bpe=True):
vocab = Dictionary() vocab = Dictionary()
vocab.add_symbol("he@@") if bpe:
vocab.add_symbol("llo") vocab.add_symbol("he@@")
vocab.add_symbol("how") vocab.add_symbol("llo")
vocab.add_symbol("are") vocab.add_symbol("how")
vocab.add_symbol("y@@") vocab.add_symbol("are")
vocab.add_symbol("ou") vocab.add_symbol("y@@")
vocab.add_symbol("n@@") vocab.add_symbol("ou")
vocab.add_symbol("ew") vocab.add_symbol("n@@")
vocab.add_symbol("or@@") vocab.add_symbol("ew")
vocab.add_symbol("k") vocab.add_symbol("or@@")
vocab.add_symbol("k")
src_tokens = [
["he@@", "llo", "n@@", "ew", "y@@", "or@@", "k"], src_tokens = [
["how", "are", "y@@", "ou"], ["he@@", "llo", "n@@", "ew", "y@@", "or@@", "k"],
] ["how", "are", "y@@", "ou"],
]
else:
vocab.add_symbol("hello")
vocab.add_symbol("how")
vocab.add_symbol("are")
vocab.add_symbol("you")
vocab.add_symbol("new")
vocab.add_symbol("york")
src_tokens = [
["hello", "new", "york", "you"],
["how", "are", "you", "new", "york"],
]
src_len = [len(x) for x in src_tokens] src_len = [len(x) for x in src_tokens]
# If we have to append EOS, we include EOS in counting src length # If we have to append EOS, we include EOS in counting src length
if append_eos: if append_eos:
...@@ -126,6 +139,22 @@ class TestDataNoising(unittest.TestCase): ...@@ -126,6 +139,22 @@ class TestDataNoising(unittest.TestCase):
self.assertEqual(x_len[0], l_noised[0]) self.assertEqual(x_len[0], l_noised[0])
self.assertEqual(x_len[1], l_noised[1]) self.assertEqual(x_len[1], l_noised[1])
def assert_nonbpe_shuffle_with_distance_3(self, x, x_noised, x_len, l_noised):
    """
    Check the expected outcome of word shuffling with
    max_shuffle_distance = 3 on the non-BPE test data. If the test data
    changes, this function must be updated to match.
    """
    # Batch element 0: the last two tokens swap places.
    # Batch element 1: the second and third tokens swap places.
    expected = (
        (0, {0: 0, 1: 1, 2: 3, 3: 2}),
        (1, {0: 0, 1: 2, 2: 1, 3: 3, 4: 4}),
    )
    for batch, position_map in expected:
        for src_pos, dst_pos in position_map.items():
            self.assertEqual(x[src_pos][batch], x_noised[dst_pos][batch])
    # Shuffling must never change sequence lengths.
    self.assertEqual(x_len[0], l_noised[0])
    self.assertEqual(x_len[1], l_noised[1])
def test_word_shuffle_with_eos(self): def test_word_shuffle_with_eos(self):
vocab, x, x_len = self._get_test_data(append_eos=True) vocab, x, x_len = self._get_test_data(append_eos=True)
...@@ -144,6 +173,24 @@ class TestDataNoising(unittest.TestCase): ...@@ -144,6 +173,24 @@ class TestDataNoising(unittest.TestCase):
) )
self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos()) self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
def test_word_shuffle_with_eos_nonbpe(self):
    """Word shuffle on word-level (non-BPE) tokens with EOS appended."""
    vocab, src, src_len = self._get_test_data(append_eos=True, bpe=False)
    with data_utils.numpy_seed(1234):
        shuffler = noising.WordShuffle(vocab, bpe_cont_marker=None)

        # A max shuffle distance of 0 must leave the batch untouched.
        noised, noised_len = shuffler.noising(src, src_len, 0)
        self.assert_no_shuffle_with_0_distance(
            x=src, x_noised=noised, x_len=src_len, l_noised=noised_len
        )
        self.assert_eos_at_end(x=noised, x_len=noised_len, eos=vocab.eos())

        # Distance 3 should yield the reference permutation for this seed.
        noised, noised_len = shuffler.noising(src, src_len, 3)
        self.assert_nonbpe_shuffle_with_distance_3(
            x=src, x_noised=noised, x_len=src_len, l_noised=noised_len
        )
        self.assert_eos_at_end(x=noised, x_len=noised_len, eos=vocab.eos())
def assert_no_eos_at_end(self, x, x_len, eos): def assert_no_eos_at_end(self, x, x_len, eos):
""" Asserts that the last token of each sentence in x is not EOS """ """ Asserts that the last token of each sentence in x is not EOS """
for i in range(len(x_len)): for i in range(len(x_len)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment