Unverified commit 9e147d31, authored by Sylvain Gugger, committed by GitHub

Deprecate prepare_seq2seq_batch (#10287)



* Deprecate prepare_seq2seq_batch

* Fix last tests

* Apply suggestions from code review
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* More review comments
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
parent e73a3e18
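
Every call site below follows one pattern: source texts are encoded by calling the tokenizer directly, and target texts are encoded under the as_target_tokenizer context manager, instead of passing tgt_texts to the now-deprecated prepare_seq2seq_batch. A minimal sketch of the migration, assuming any seq2seq tokenizer (the checkpoint name and example strings here are illustrative, not taken from this diff):

# Sketch of the migration this commit applies across the test suite
# (illustrative checkpoint and texts).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")

src_text = ["A long paragraph for summarization."]
tgt_text = ["Summary of the text."]

# Deprecated by this commit:
# batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors="pt")

# Replacement: encode sources with a plain tokenizer call...
batch = tokenizer(src_text, padding=True, truncation=True, return_tensors="pt")
# ...and encode targets under the as_target_tokenizer context manager.
with tokenizer.as_target_tokenizer():
    batch["labels"] = tokenizer(tgt_text, padding=True, truncation=True, return_tensors="pt")["input_ids"]
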
@@ -363,9 +363,7 @@ class AbstractMarianIntegrationTest(unittest.TestCase):
         self.assertListEqual(self.expected_text, generated_words)

     def translate_src_text(self, **tokenizer_kwargs):
-        model_inputs = self.tokenizer.prepare_seq2seq_batch(
-            src_texts=self.src_text, **tokenizer_kwargs, return_tensors="tf"
-        )
+        model_inputs = self.tokenizer(self.src_text, **tokenizer_kwargs, return_tensors="tf")
         generated_ids = self.model.generate(
             model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2, max_length=128
         )
...
@@ -330,9 +330,7 @@ class TFMBartModelIntegrationTest(unittest.TestCase):
         self.assertListEqual(self.expected_text, generated_words)

     def translate_src_text(self, **tokenizer_kwargs):
-        model_inputs = self.tokenizer.prepare_seq2seq_batch(
-            src_texts=self.src_text, **tokenizer_kwargs, return_tensors="tf"
-        )
+        model_inputs = self.tokenizer(self.src_text, **tokenizer_kwargs, return_tensors="tf")
         generated_ids = self.model.generate(
             model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2
         )
...
@@ -356,9 +356,7 @@ class TFPegasusIntegrationTests(unittest.TestCase):
         assert self.expected_text == generated_words

     def translate_src_text(self, **tokenizer_kwargs):
-        model_inputs = self.tokenizer.prepare_seq2seq_batch(
-            src_texts=self.src_text, **tokenizer_kwargs, return_tensors="tf"
-        )
+        model_inputs = self.tokenizer(self.src_text, **tokenizer_kwargs, return_tensors="tf")
         generated_ids = self.model.generate(
             model_inputs.input_ids,
             attention_mask=model_inputs.attention_mask,
...
@@ -86,18 +86,12 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
         return BartTokenizerFast.from_pretrained("facebook/bart-large")

     @require_torch
-    def test_prepare_seq2seq_batch(self):
+    def test_prepare_batch(self):
         src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
-        tgt_text = [
-            "Summary of the text.",
-            "Another summary.",
-        ]
         expected_src_tokens = [0, 250, 251, 17818, 13, 39186, 1938, 4, 2]

         for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
-            batch = tokenizer.prepare_seq2seq_batch(
-                src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors="pt"
-            )
+            batch = tokenizer(src_text, max_length=len(expected_src_tokens), padding=True, return_tensors="pt")
             self.assertIsInstance(batch, BatchEncoding)

             self.assertEqual((2, 9), batch.input_ids.shape)
@@ -106,12 +100,11 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
             self.assertListEqual(expected_src_tokens, result)
             # Test that special tokens are reset

-    # Test Prepare Seq
     @require_torch
-    def test_seq2seq_batch_empty_target_text(self):
+    def test_prepare_batch_empty_target_text(self):
         src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
         for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
-            batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt")
+            batch = tokenizer(src_text, padding=True, return_tensors="pt")
             # check if input_ids are returned and no labels
             self.assertIn("input_ids", batch)
             self.assertIn("attention_mask", batch)
@@ -119,29 +112,21 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
             self.assertNotIn("decoder_attention_mask", batch)

     @require_torch
-    def test_seq2seq_batch_max_target_length(self):
-        src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
+    def test_as_target_tokenizer_target_length(self):
         tgt_text = [
             "Summary of the text.",
             "Another summary.",
         ]
         for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
-            batch = tokenizer.prepare_seq2seq_batch(
-                src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors="pt"
-            )
-            self.assertEqual(32, batch["labels"].shape[1])
-
-            # test None max_target_length
-            batch = tokenizer.prepare_seq2seq_batch(
-                src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors="pt"
-            )
-            self.assertEqual(32, batch["labels"].shape[1])
+            with tokenizer.as_target_tokenizer():
+                targets = tokenizer(tgt_text, max_length=32, padding="max_length", return_tensors="pt")
+            self.assertEqual(32, targets["input_ids"].shape[1])

     @require_torch
-    def test_seq2seq_batch_not_longer_than_maxlen(self):
+    def test_prepare_batch_not_longer_than_maxlen(self):
         for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
-            batch = tokenizer.prepare_seq2seq_batch(
-                ["I am a small frog" * 1024, "I am a small frog"], return_tensors="pt"
-            )
+            batch = tokenizer(
+                ["I am a small frog" * 1024, "I am a small frog"], padding=True, truncation=True, return_tensors="pt"
+            )
             self.assertIsInstance(batch, BatchEncoding)
             self.assertEqual(batch.input_ids.shape, (2, 1024))
@@ -154,9 +139,11 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
             "Summary of the text.",
         ]
         for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
-            batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors="pt")
-            input_ids = batch["input_ids"]
-            labels = batch["labels"]
+            inputs = tokenizer(src_text, return_tensors="pt")
+            with tokenizer.as_target_tokenizer():
+                targets = tokenizer(tgt_text, return_tensors="pt")
+            input_ids = inputs["input_ids"]
+            labels = targets["input_ids"]
             self.assertTrue((input_ids[:, 0] == tokenizer.bos_token_id).all().item())
             self.assertTrue((labels[:, 0] == tokenizer.bos_token_id).all().item())
             self.assertTrue((input_ids[:, -1] == tokenizer.eos_token_id).all().item())
...
@@ -38,16 +38,12 @@ class BarthezTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         self.tokenizer = tokenizer

     @require_torch
-    def test_prepare_seq2seq_batch(self):
+    def test_prepare_batch(self):
         src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
-        tgt_text = [
-            "Summary of the text.",
-            "Another summary.",
-        ]
         expected_src_tokens = [0, 57, 3018, 70307, 91, 2]
-        batch = self.tokenizer.prepare_seq2seq_batch(
-            src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors="pt"
-        )
+        batch = self.tokenizer(
+            src_text, max_length=len(expected_src_tokens), padding=True, truncation=True, return_tensors="pt"
+        )
         self.assertIsInstance(batch, BatchEncoding)
...
@@ -70,7 +70,7 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     def test_tokenizer_equivalence_en_de(self):
         en_de_tokenizer = MarianTokenizer.from_pretrained(f"{ORG_NAME}opus-mt-en-de")
-        batch = en_de_tokenizer.prepare_seq2seq_batch(["I am a small frog"], return_tensors=None)
+        batch = en_de_tokenizer(["I am a small frog"], return_tensors=None)
         self.assertIsInstance(batch, BatchEncoding)
         expected = [38, 121, 14, 697, 38848, 0]
         self.assertListEqual(expected, batch.input_ids[0])
@@ -84,12 +84,14 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     def test_outputs_not_longer_than_maxlen(self):
         tok = self.get_tokenizer()
-        batch = tok.prepare_seq2seq_batch(["I am a small frog" * 1000, "I am a small frog"], return_tensors=FRAMEWORK)
+        batch = tok(
+            ["I am a small frog" * 1000, "I am a small frog"], padding=True, truncation=True, return_tensors=FRAMEWORK
+        )
         self.assertIsInstance(batch, BatchEncoding)
         self.assertEqual(batch.input_ids.shape, (2, 512))

     def test_outputs_can_be_shorter(self):
         tok = self.get_tokenizer()
-        batch_smaller = tok.prepare_seq2seq_batch(["I am a tiny frog", "I am a small frog"], return_tensors=FRAMEWORK)
+        batch_smaller = tok(["I am a tiny frog", "I am a small frog"], padding=True, return_tensors=FRAMEWORK)
         self.assertIsInstance(batch_smaller, BatchEncoding)
         self.assertEqual(batch_smaller.input_ids.shape, (2, 10))
@@ -141,7 +141,9 @@ class MBartEnroIntegrationTest(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.tokenizer: MBartTokenizer = MBartTokenizer.from_pretrained(cls.checkpoint_name)
+        cls.tokenizer: MBartTokenizer = MBartTokenizer.from_pretrained(
+            cls.checkpoint_name, src_lang="en_XX", tgt_lang="ro_RO"
+        )
         cls.pad_token_id = 1
         return cls
@@ -166,10 +168,7 @@ class MBartEnroIntegrationTest(unittest.TestCase):
         src_text = ["this is gunna be a long sentence " * 20]
         assert isinstance(src_text[0], str)
         desired_max_length = 10
-        ids = self.tokenizer.prepare_seq2seq_batch(
-            src_text,
-            max_length=desired_max_length,
-        ).input_ids[0]
+        ids = self.tokenizer(src_text, max_length=desired_max_length, truncation=True).input_ids[0]
         self.assertEqual(ids[-2], 2)
         self.assertEqual(ids[-1], EN_CODE)
         self.assertEqual(len(ids), desired_max_length)
@@ -184,31 +183,36 @@ class MBartEnroIntegrationTest(unittest.TestCase):
             new_tok = MBartTokenizer.from_pretrained(tmpdirname)
             self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens)

-    # prepare_seq2seq_batch tests below
     @require_torch
     def test_batch_fairseq_parity(self):
-        batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
-            self.src_text, tgt_texts=self.tgt_text, return_tensors="pt"
-        )
-        batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
-
-        for k in batch:
-            batch[k] = batch[k].tolist()
-        # batch = {k: v.tolist() for k,v in batch.items()}
+        batch = self.tokenizer(self.src_text, padding=True)
+        with self.tokenizer.as_target_tokenizer():
+            targets = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt")
+        labels = targets["input_ids"]
+        batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id).tolist()
+
         # fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
-        # batch.decoder_inputs_ids[0][0] ==
         assert batch.input_ids[1][-2:] == [2, EN_CODE]
         assert batch.decoder_input_ids[1][0] == RO_CODE
         assert batch.decoder_input_ids[1][-1] == 2
-        assert batch.labels[1][-2:] == [2, RO_CODE]
+        assert labels[1][-2:].tolist() == [2, RO_CODE]

     @require_torch
-    def test_enro_tokenizer_prepare_seq2seq_batch(self):
-        batch = self.tokenizer.prepare_seq2seq_batch(
-            self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens), return_tensors="pt"
-        )
-        batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
+    def test_enro_tokenizer_prepare_batch(self):
+        batch = self.tokenizer(
+            self.src_text, padding=True, truncation=True, max_length=len(self.expected_src_tokens), return_tensors="pt"
+        )
+        with self.tokenizer.as_target_tokenizer():
+            targets = self.tokenizer(
+                self.tgt_text,
+                padding=True,
+                truncation=True,
+                max_length=len(self.expected_src_tokens),
+                return_tensors="pt",
+            )
+        labels = targets["input_ids"]
+        batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
         self.assertIsInstance(batch, BatchEncoding)

         self.assertEqual((2, 14), batch.input_ids.shape)
@@ -220,17 +224,12 @@ class MBartEnroIntegrationTest(unittest.TestCase):
         self.assertEqual(self.tokenizer.prefix_tokens, [])
         self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])

-    def test_seq2seq_max_target_length(self):
-        batch = self.tokenizer.prepare_seq2seq_batch(
-            self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10, return_tensors="pt"
-        )
-        batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
+    def test_seq2seq_max_length(self):
+        batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
+        with self.tokenizer.as_target_tokenizer():
+            targets = self.tokenizer(self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt")
+        labels = targets["input_ids"]
+        batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
         self.assertEqual(batch.input_ids.shape[1], 3)
         self.assertEqual(batch.decoder_input_ids.shape[1], 10)
-
-        # max_target_length will default to max_length if not specified
-        batch = self.tokenizer.prepare_seq2seq_batch(
-            self.src_text, tgt_texts=self.tgt_text, max_length=3, return_tensors="pt"
-        )
-        batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
-        self.assertEqual(batch.input_ids.shape[1], 3)
-        self.assertEqual(batch.decoder_input_ids.shape[1], 3)
@@ -129,10 +129,7 @@ class MBartOneToManyIntegrationTest(unittest.TestCase):
         src_text = ["this is gunna be a long sentence " * 20]
         assert isinstance(src_text[0], str)
         desired_max_length = 10
-        ids = self.tokenizer.prepare_seq2seq_batch(
-            src_text,
-            max_length=desired_max_length,
-        ).input_ids[0]
+        ids = self.tokenizer(src_text, max_length=desired_max_length, truncation=True).input_ids[0]
         self.assertEqual(ids[0], EN_CODE)
         self.assertEqual(ids[-1], 2)
         self.assertEqual(len(ids), desired_max_length)
@@ -147,32 +144,38 @@ class MBartOneToManyIntegrationTest(unittest.TestCase):
             new_tok = MBart50Tokenizer.from_pretrained(tmpdirname)
             self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens)

-    # prepare_seq2seq_batch tests below
     @require_torch
     def test_batch_fairseq_parity(self):
-        batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
-            self.src_text, tgt_texts=self.tgt_text, return_tensors="pt"
-        )
-        batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
-
-        for k in batch:
-            batch[k] = batch[k].tolist()
-        # batch = {k: v.tolist() for k,v in batch.items()}
+        batch = self.tokenizer(self.src_text, padding=True)
+        with self.tokenizer.as_target_tokenizer():
+            targets = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt")
+        labels = targets["input_ids"]
+        batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id).tolist()
+        labels = labels.tolist()
+
         # fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
-        # batch.decoder_inputs_ids[0][0] ==
         assert batch.input_ids[1][0] == EN_CODE
         assert batch.input_ids[1][-1] == 2
-        assert batch.labels[1][0] == RO_CODE
-        assert batch.labels[1][-1] == 2
+        assert labels[1][0] == RO_CODE
+        assert labels[1][-1] == 2
         assert batch.decoder_input_ids[1][:2] == [2, RO_CODE]

     @require_torch
-    def test_tokenizer_prepare_seq2seq_batch(self):
-        batch = self.tokenizer.prepare_seq2seq_batch(
-            self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens), return_tensors="pt"
-        )
-        batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
+    def test_tokenizer_prepare_batch(self):
+        batch = self.tokenizer(
+            self.src_text, padding=True, truncation=True, max_length=len(self.expected_src_tokens), return_tensors="pt"
+        )
+        with self.tokenizer.as_target_tokenizer():
+            targets = self.tokenizer(
+                self.tgt_text,
+                padding=True,
+                truncation=True,
+                max_length=len(self.expected_src_tokens),
+                return_tensors="pt",
+            )
+        labels = targets["input_ids"]
+        batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
         self.assertIsInstance(batch, BatchEncoding)

         self.assertEqual((2, 14), batch.input_ids.shape)
@@ -185,16 +188,11 @@ class MBartOneToManyIntegrationTest(unittest.TestCase):
         self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])

     def test_seq2seq_max_target_length(self):
-        batch = self.tokenizer.prepare_seq2seq_batch(
-            self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10, return_tensors="pt"
-        )
-        batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
+        batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
+        with self.tokenizer.as_target_tokenizer():
+            targets = self.tokenizer(self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt")
+        labels = targets["input_ids"]
+        batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
         self.assertEqual(batch.input_ids.shape[1], 3)
         self.assertEqual(batch.decoder_input_ids.shape[1], 10)
-
-        # max_target_length will default to max_length if not specified
-        batch = self.tokenizer.prepare_seq2seq_batch(
-            self.src_text, tgt_texts=self.tgt_text, max_length=3, return_tensors="pt"
-        )
-        batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
-        self.assertEqual(batch.input_ids.shape[1], 3)
-        self.assertEqual(batch.decoder_input_ids.shape[1], 3)
@@ -86,11 +86,13 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     def test_large_seq2seq_truncation(self):
         src_texts = ["This is going to be way too long." * 150, "short example"]
         tgt_texts = ["not super long but more than 5 tokens", "tiny"]
-        batch = self._large_tokenizer.prepare_seq2seq_batch(
-            src_texts, tgt_texts=tgt_texts, max_target_length=5, return_tensors="pt"
-        )
+        batch = self._large_tokenizer(src_texts, padding=True, truncation=True, return_tensors="pt")
+        with self._large_tokenizer.as_target_tokenizer():
+            targets = self._large_tokenizer(
+                tgt_texts, max_length=5, padding=True, truncation=True, return_tensors="pt"
+            )

         assert batch.input_ids.shape == (2, 1024)
         assert batch.attention_mask.shape == (2, 1024)
-        assert "labels" in batch  # because tgt_texts was specified
-        assert batch.labels.shape == (2, 5)
-        assert len(batch) == 3  # input_ids, attention_mask, labels. Other things make by BartModel
+        assert targets["input_ids"].shape == (2, 5)
+        assert len(batch) == 2  # input_ids, attention_mask.
@@ -152,20 +152,12 @@ class ProphetNetTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         self.assertListEqual(tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])

     @require_torch
-    def test_prepare_seq2seq_batch(self):
+    def test_prepare_batch(self):
         tokenizer = self.tokenizer_class.from_pretrained("microsoft/prophetnet-large-uncased")
         src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
-        tgt_text = [
-            "Summary of the text.",
-            "Another summary.",
-        ]
         expected_src_tokens = [1037, 2146, 20423, 2005, 7680, 7849, 3989, 1012, 102]
-        batch = tokenizer.prepare_seq2seq_batch(
-            src_text,
-            tgt_texts=tgt_text,
-            return_tensors="pt",
-        )
+        batch = tokenizer(src_text, padding=True, return_tensors="pt")
         self.assertIsInstance(batch, BatchEncoding)
         result = list(batch.input_ids.numpy()[0])
         self.assertListEqual(expected_src_tokens, result)
...
@@ -151,19 +151,11 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         batch_without_eos_added = tokenizer(["hi", "I went to the gym", ""])
         self.assertListEqual(batch_with_eos_added["input_ids"], batch_without_eos_added["input_ids"])

-    def test_prepare_seq2seq_batch(self):
+    def test_prepare_batch(self):
         tokenizer = self.t5_base_tokenizer
         src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
-        tgt_text = [
-            "Summary of the text.",
-            "Another summary.",
-        ]
         expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, tokenizer.eos_token_id]
-        batch = tokenizer.prepare_seq2seq_batch(
-            src_text,
-            tgt_texts=tgt_text,
-            return_tensors=FRAMEWORK,
-        )
+        batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK)
         self.assertIsInstance(batch, BatchEncoding)
         result = list(batch.input_ids.numpy()[0])
         self.assertListEqual(expected_src_tokens, result)
...
@@ -174,36 +166,30 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     def test_empty_target_text(self):
         tokenizer = self.t5_base_tokenizer
         src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
-        batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors=FRAMEWORK)
+        batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK)
         # check if input_ids are returned and no decoder_input_ids
         self.assertIn("input_ids", batch)
         self.assertIn("attention_mask", batch)
         self.assertNotIn("decoder_input_ids", batch)
         self.assertNotIn("decoder_attention_mask", batch)

-    def test_max_target_length(self):
+    def test_max_length(self):
         tokenizer = self.t5_base_tokenizer
-        src_text = ["A short paragraph for summarization.", "Another short paragraph for summarization."]
         tgt_text = [
             "Summary of the text.",
             "Another summary.",
         ]
-        batch = tokenizer.prepare_seq2seq_batch(
-            src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors=FRAMEWORK
-        )
-        self.assertEqual(32, batch["labels"].shape[1])
-
-        # test None max_target_length
-        batch = tokenizer.prepare_seq2seq_batch(
-            src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors=FRAMEWORK
-        )
-        self.assertEqual(32, batch["labels"].shape[1])
+        with tokenizer.as_target_tokenizer():
+            targets = tokenizer(
+                tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
+            )
+        self.assertEqual(32, targets["input_ids"].shape[1])

     def test_outputs_not_longer_than_maxlen(self):
         tokenizer = self.t5_base_tokenizer
-        batch = tokenizer.prepare_seq2seq_batch(
-            ["I am a small frog" * 1000, "I am a small frog"], return_tensors=FRAMEWORK
-        )
+        batch = tokenizer(
+            ["I am a small frog" * 1000, "I am a small frog"], padding=True, truncation=True, return_tensors=FRAMEWORK
+        )
         self.assertIsInstance(batch, BatchEncoding)
         self.assertEqual(batch.input_ids.shape, (2, 512))
@@ -215,13 +201,12 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, 1]
         expected_tgt_tokens = [20698, 13, 8, 1499, 5, 1]

-        batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors=FRAMEWORK)
-
-        src_ids = list(batch.input_ids.numpy()[0])
-        tgt_ids = list(batch.labels.numpy()[0])
-
-        self.assertEqual(expected_src_tokens, src_ids)
-        self.assertEqual(expected_tgt_tokens, tgt_ids)
+        batch = tokenizer(src_text)
+        with tokenizer.as_target_tokenizer():
+            targets = tokenizer(tgt_text)
+
+        self.assertEqual(expected_src_tokens, batch["input_ids"][0])
+        self.assertEqual(expected_tgt_tokens, targets["input_ids"][0])

     def test_token_type_ids(self):
         src_text_1 = ["A first paragraph for summarization."]
...
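
For MBart-style models, the tests above also rebuild decoder_input_ids from the target encoding instead of relying on prepare_seq2seq_batch to produce them. A condensed sketch of that pattern, assuming the MBart en-ro checkpoint used by these tests (the input strings are illustrative):

# Condensed sketch of the MBart pattern exercised in the tests above
# (illustrative texts; shift_tokens_right is the MBart helper the tests import).
from transformers import MBartTokenizer
from transformers.models.mbart.modeling_mbart import shift_tokens_right

tokenizer = MBartTokenizer.from_pretrained(
    "facebook/mbart-large-en-ro", src_lang="en_XX", tgt_lang="ro_RO"
)
batch = tokenizer(["this is gunna be a long sentence"], return_tensors="pt")
with tokenizer.as_target_tokenizer():
    labels = tokenizer(["aceasta va fi o propozitie lunga"], return_tensors="pt")["input_ids"]
# Teacher-forcing inputs are the labels shifted right, as in the tests above.
batch["decoder_input_ids"] = shift_tokens_right(labels, tokenizer.pad_token_id)
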