Unverified commit 9e147d31, authored by Sylvain Gugger, committed by GitHub

Deprecate prepare_seq2seq_batch (#10287)



* Deprecate prepare_seq2seq_batch

* Fix last tests

* Apply suggestions from code review
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* More review comments
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
parent e73a3e18
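The change replaces the deprecated tokenizer.prepare_seq2seq_batch(...) API with a plain tokenizer call for the source texts and an as_target_tokenizer() context manager for the target texts, as the hunks below show for the Marian, mBART, Pegasus, BART, BARThez, ProphetNet and T5 tests. A minimal migration sketch under those assumptions (checkpoint and example strings are borrowed from the BART test below; treat the labels assembly as an illustrative pattern, not the only supported one):

from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
tgt_text = ["Summary of the text.", "Another summary."]

# Deprecated: one call handled sources and targets, returning the targets under "labels"
# batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors="pt")

# Replacement: tokenize the sources with a regular call ...
batch = tokenizer(src_text, padding=True, truncation=True, return_tensors="pt")
# ... and the targets inside as_target_tokenizer(), then attach them as labels
with tokenizer.as_target_tokenizer():
    labels = tokenizer(tgt_text, padding=True, truncation=True, return_tensors="pt")
batch["labels"] = labels["input_ids"]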
@@ -363,9 +363,7 @@ class AbstractMarianIntegrationTest(unittest.TestCase):
         self.assertListEqual(self.expected_text, generated_words)
 
     def translate_src_text(self, **tokenizer_kwargs):
-        model_inputs = self.tokenizer.prepare_seq2seq_batch(
-            src_texts=self.src_text, **tokenizer_kwargs, return_tensors="tf"
-        )
+        model_inputs = self.tokenizer(self.src_text, **tokenizer_kwargs, return_tensors="tf")
         generated_ids = self.model.generate(
             model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2, max_length=128
         )
......
@@ -330,9 +330,7 @@ class TFMBartModelIntegrationTest(unittest.TestCase):
         self.assertListEqual(self.expected_text, generated_words)
 
     def translate_src_text(self, **tokenizer_kwargs):
-        model_inputs = self.tokenizer.prepare_seq2seq_batch(
-            src_texts=self.src_text, **tokenizer_kwargs, return_tensors="tf"
-        )
+        model_inputs = self.tokenizer(self.src_text, **tokenizer_kwargs, return_tensors="tf")
         generated_ids = self.model.generate(
             model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2
         )
......
@@ -356,9 +356,7 @@ class TFPegasusIntegrationTests(unittest.TestCase):
         assert self.expected_text == generated_words
 
     def translate_src_text(self, **tokenizer_kwargs):
-        model_inputs = self.tokenizer.prepare_seq2seq_batch(
-            src_texts=self.src_text, **tokenizer_kwargs, return_tensors="tf"
-        )
+        model_inputs = self.tokenizer(self.src_text, **tokenizer_kwargs, return_tensors="tf")
         generated_ids = self.model.generate(
             model_inputs.input_ids,
             attention_mask=model_inputs.attention_mask,
......
@@ -86,18 +86,12 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
         return BartTokenizerFast.from_pretrained("facebook/bart-large")
 
     @require_torch
-    def test_prepare_seq2seq_batch(self):
+    def test_prepare_batch(self):
         src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
-        tgt_text = [
-            "Summary of the text.",
-            "Another summary.",
-        ]
         expected_src_tokens = [0, 250, 251, 17818, 13, 39186, 1938, 4, 2]
 
         for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
-            batch = tokenizer.prepare_seq2seq_batch(
-                src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors="pt"
-            )
+            batch = tokenizer(src_text, max_length=len(expected_src_tokens), padding=True, return_tensors="pt")
             self.assertIsInstance(batch, BatchEncoding)
 
             self.assertEqual((2, 9), batch.input_ids.shape)
@@ -106,12 +100,11 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
             self.assertListEqual(expected_src_tokens, result)
             # Test that special tokens are reset
 
-    # Test Prepare Seq
     @require_torch
-    def test_seq2seq_batch_empty_target_text(self):
+    def test_prepare_batch_empty_target_text(self):
         src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
         for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
-            batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt")
+            batch = tokenizer(src_text, padding=True, return_tensors="pt")
             # check if input_ids are returned and no labels
             self.assertIn("input_ids", batch)
             self.assertIn("attention_mask", batch)
@@ -119,29 +112,21 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
             self.assertNotIn("decoder_attention_mask", batch)
 
     @require_torch
-    def test_seq2seq_batch_max_target_length(self):
-        src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
+    def test_as_target_tokenizer_target_length(self):
         tgt_text = [
             "Summary of the text.",
            "Another summary.",
        ]
         for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
-            batch = tokenizer.prepare_seq2seq_batch(
-                src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors="pt"
-            )
-            self.assertEqual(32, batch["labels"].shape[1])
-
-            # test None max_target_length
-            batch = tokenizer.prepare_seq2seq_batch(
-                src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors="pt"
-            )
-            self.assertEqual(32, batch["labels"].shape[1])
+            with tokenizer.as_target_tokenizer():
+                targets = tokenizer(tgt_text, max_length=32, padding="max_length", return_tensors="pt")
+            self.assertEqual(32, targets["input_ids"].shape[1])
 
     @require_torch
-    def test_seq2seq_batch_not_longer_than_maxlen(self):
+    def test_prepare_batch_not_longer_than_maxlen(self):
         for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
-            batch = tokenizer.prepare_seq2seq_batch(
-                ["I am a small frog" * 1024, "I am a small frog"], return_tensors="pt"
+            batch = tokenizer(
+                ["I am a small frog" * 1024, "I am a small frog"], padding=True, truncation=True, return_tensors="pt"
             )
             self.assertIsInstance(batch, BatchEncoding)
             self.assertEqual(batch.input_ids.shape, (2, 1024))
@@ -154,9 +139,11 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
             "Summary of the text.",
         ]
         for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
-            batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors="pt")
-            input_ids = batch["input_ids"]
-            labels = batch["labels"]
+            inputs = tokenizer(src_text, return_tensors="pt")
+            with tokenizer.as_target_tokenizer():
+                targets = tokenizer(tgt_text, return_tensors="pt")
+            input_ids = inputs["input_ids"]
+            labels = targets["input_ids"]
             self.assertTrue((input_ids[:, 0] == tokenizer.bos_token_id).all().item())
             self.assertTrue((labels[:, 0] == tokenizer.bos_token_id).all().item())
             self.assertTrue((input_ids[:, -1] == tokenizer.eos_token_id).all().item())
......
@@ -38,16 +38,12 @@ class BarthezTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         self.tokenizer = tokenizer
 
     @require_torch
-    def test_prepare_seq2seq_batch(self):
+    def test_prepare_batch(self):
         src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
-        tgt_text = [
-            "Summary of the text.",
-            "Another summary.",
-        ]
         expected_src_tokens = [0, 57, 3018, 70307, 91, 2]
 
-        batch = self.tokenizer.prepare_seq2seq_batch(
-            src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors="pt"
+        batch = self.tokenizer(
+            src_text, max_length=len(expected_src_tokens), padding=True, truncation=True, return_tensors="pt"
         )
 
         self.assertIsInstance(batch, BatchEncoding)
......
@@ -70,7 +70,7 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
 
     def test_tokenizer_equivalence_en_de(self):
         en_de_tokenizer = MarianTokenizer.from_pretrained(f"{ORG_NAME}opus-mt-en-de")
-        batch = en_de_tokenizer.prepare_seq2seq_batch(["I am a small frog"], return_tensors=None)
+        batch = en_de_tokenizer(["I am a small frog"], return_tensors=None)
         self.assertIsInstance(batch, BatchEncoding)
         expected = [38, 121, 14, 697, 38848, 0]
         self.assertListEqual(expected, batch.input_ids[0])
@@ -84,12 +84,14 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     def test_outputs_not_longer_than_maxlen(self):
         tok = self.get_tokenizer()
 
-        batch = tok.prepare_seq2seq_batch(["I am a small frog" * 1000, "I am a small frog"], return_tensors=FRAMEWORK)
+        batch = tok(
+            ["I am a small frog" * 1000, "I am a small frog"], padding=True, truncation=True, return_tensors=FRAMEWORK
+        )
         self.assertIsInstance(batch, BatchEncoding)
         self.assertEqual(batch.input_ids.shape, (2, 512))
 
     def test_outputs_can_be_shorter(self):
         tok = self.get_tokenizer()
-        batch_smaller = tok.prepare_seq2seq_batch(["I am a tiny frog", "I am a small frog"], return_tensors=FRAMEWORK)
+        batch_smaller = tok(["I am a tiny frog", "I am a small frog"], padding=True, return_tensors=FRAMEWORK)
         self.assertIsInstance(batch_smaller, BatchEncoding)
         self.assertEqual(batch_smaller.input_ids.shape, (2, 10))
@@ -141,7 +141,9 @@ class MBartEnroIntegrationTest(unittest.TestCase):
 
     @classmethod
     def setUpClass(cls):
-        cls.tokenizer: MBartTokenizer = MBartTokenizer.from_pretrained(cls.checkpoint_name)
+        cls.tokenizer: MBartTokenizer = MBartTokenizer.from_pretrained(
+            cls.checkpoint_name, src_lang="en_XX", tgt_lang="ro_RO"
+        )
         cls.pad_token_id = 1
         return cls
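Note on the hunk above: prepare_seq2seq_batch previously accepted the source and target language codes for multilingual tokenizers, so the mBART tests now pass them when the tokenizer is instantiated instead. A short sketch of that setup (the checkpoint name here is illustrative, not taken from this diff):

from transformers import MBartTokenizer

# Language codes move from prepare_seq2seq_batch arguments to the tokenizer constructor.
tokenizer = MBartTokenizer.from_pretrained(
    "facebook/mbart-large-en-ro", src_lang="en_XX", tgt_lang="ro_RO"  # illustrative checkpoint
)
ids = tokenizer(["this is gunna be a long sentence " * 20], max_length=10, truncation=True).input_ids[0]
# For this tokenizer each truncated source sequence still ends with the eos token followed by
# the en_XX language code, which is what the truncation test below asserts.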
@@ -166,10 +168,7 @@ class MBartEnroIntegrationTest(unittest.TestCase):
         src_text = ["this is gunna be a long sentence " * 20]
         assert isinstance(src_text[0], str)
         desired_max_length = 10
-        ids = self.tokenizer.prepare_seq2seq_batch(
-            src_text,
-            max_length=desired_max_length,
-        ).input_ids[0]
+        ids = self.tokenizer(src_text, max_length=desired_max_length, truncation=True).input_ids[0]
         self.assertEqual(ids[-2], 2)
         self.assertEqual(ids[-1], EN_CODE)
         self.assertEqual(len(ids), desired_max_length)
@@ -184,31 +183,36 @@ class MBartEnroIntegrationTest(unittest.TestCase):
         new_tok = MBartTokenizer.from_pretrained(tmpdirname)
         self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens)
 
-    # prepare_seq2seq_batch tests below
     @require_torch
     def test_batch_fairseq_parity(self):
-        batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
-            self.src_text, tgt_texts=self.tgt_text, return_tensors="pt"
-        )
-        batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
+        batch = self.tokenizer(self.src_text, padding=True)
+        with self.tokenizer.as_target_tokenizer():
+            targets = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt")
+        labels = targets["input_ids"]
+        batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id).tolist()
-        for k in batch:
-            batch[k] = batch[k].tolist()
-        # batch = {k: v.tolist() for k,v in batch.items()}
 
         # fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
-        # batch.decoder_inputs_ids[0][0] ==
         assert batch.input_ids[1][-2:] == [2, EN_CODE]
         assert batch.decoder_input_ids[1][0] == RO_CODE
         assert batch.decoder_input_ids[1][-1] == 2
-        assert batch.labels[1][-2:] == [2, RO_CODE]
+        assert labels[1][-2:].tolist() == [2, RO_CODE]
 
     @require_torch
-    def test_enro_tokenizer_prepare_seq2seq_batch(self):
-        batch = self.tokenizer.prepare_seq2seq_batch(
-            self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens), return_tensors="pt"
+    def test_enro_tokenizer_prepare_batch(self):
+        batch = self.tokenizer(
+            self.src_text, padding=True, truncation=True, max_length=len(self.expected_src_tokens), return_tensors="pt"
         )
-        batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
+        with self.tokenizer.as_target_tokenizer():
+            targets = self.tokenizer(
+                self.tgt_text,
+                padding=True,
+                truncation=True,
+                max_length=len(self.expected_src_tokens),
+                return_tensors="pt",
+            )
+        labels = targets["input_ids"]
+        batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
 
         self.assertIsInstance(batch, BatchEncoding)
         self.assertEqual((2, 14), batch.input_ids.shape)
@@ -220,17 +224,12 @@ class MBartEnroIntegrationTest(unittest.TestCase):
         self.assertEqual(self.tokenizer.prefix_tokens, [])
         self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])
 
-    def test_seq2seq_max_target_length(self):
-        batch = self.tokenizer.prepare_seq2seq_batch(
-            self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10, return_tensors="pt"
-        )
-        batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
+    def test_seq2seq_max_length(self):
+        batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
+        with self.tokenizer.as_target_tokenizer():
+            targets = self.tokenizer(self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt")
+        labels = targets["input_ids"]
+        batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
 
         self.assertEqual(batch.input_ids.shape[1], 3)
         self.assertEqual(batch.decoder_input_ids.shape[1], 10)
-        # max_target_length will default to max_length if not specified
-        batch = self.tokenizer.prepare_seq2seq_batch(
-            self.src_text, tgt_texts=self.tgt_text, max_length=3, return_tensors="pt"
-        )
-        batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
-        self.assertEqual(batch.input_ids.shape[1], 3)
-        self.assertEqual(batch.decoder_input_ids.shape[1], 3)
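Because the tokenizer call no longer takes tgt_texts or max_target_length, the tests above tokenize the labels separately (with their own max_length) and derive decoder_input_ids from them with shift_tokens_right. A condensed sketch of that pattern; the import path, checkpoint name and example strings are assumptions for illustration, the rest mirrors the hunk above:

from transformers import MBartTokenizer
# Assumed location of the helper used in these tests; adjust if your version differs.
from transformers.models.mbart.modeling_mbart import shift_tokens_right

tokenizer = MBartTokenizer.from_pretrained(
    "facebook/mbart-large-en-ro", src_lang="en_XX", tgt_lang="ro_RO"  # illustrative checkpoint
)
src_text = ["A long source sentence."]  # illustrative
tgt_text = ["O propozitie tinta."]  # illustrative

batch = tokenizer(src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
with tokenizer.as_target_tokenizer():
    labels = tokenizer(tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt")["input_ids"]

# mBART's shift_tokens_right rotates each sequence so the target language code comes first,
# producing decoder inputs that sit one step behind the labels.
batch["decoder_input_ids"] = shift_tokens_right(labels, tokenizer.pad_token_id)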
@@ -129,10 +129,7 @@ class MBartOneToManyIntegrationTest(unittest.TestCase):
         src_text = ["this is gunna be a long sentence " * 20]
         assert isinstance(src_text[0], str)
         desired_max_length = 10
-        ids = self.tokenizer.prepare_seq2seq_batch(
-            src_text,
-            max_length=desired_max_length,
-        ).input_ids[0]
+        ids = self.tokenizer(src_text, max_length=desired_max_length, truncation=True).input_ids[0]
         self.assertEqual(ids[0], EN_CODE)
         self.assertEqual(ids[-1], 2)
         self.assertEqual(len(ids), desired_max_length)
@@ -147,32 +144,38 @@ class MBartOneToManyIntegrationTest(unittest.TestCase):
         new_tok = MBart50Tokenizer.from_pretrained(tmpdirname)
         self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens)
 
-    # prepare_seq2seq_batch tests below
     @require_torch
     def test_batch_fairseq_parity(self):
-        batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
-            self.src_text, tgt_texts=self.tgt_text, return_tensors="pt"
-        )
-        batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
+        batch = self.tokenizer(self.src_text, padding=True)
+        with self.tokenizer.as_target_tokenizer():
+            targets = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt")
+        labels = targets["input_ids"]
+        batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id).tolist()
+        labels = labels.tolist()
-        for k in batch:
-            batch[k] = batch[k].tolist()
-        # batch = {k: v.tolist() for k,v in batch.items()}
 
         # fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
-        # batch.decoder_inputs_ids[0][0] ==
         assert batch.input_ids[1][0] == EN_CODE
         assert batch.input_ids[1][-1] == 2
-        assert batch.labels[1][0] == RO_CODE
-        assert batch.labels[1][-1] == 2
+        assert labels[1][0] == RO_CODE
+        assert labels[1][-1] == 2
         assert batch.decoder_input_ids[1][:2] == [2, RO_CODE]
 
     @require_torch
-    def test_tokenizer_prepare_seq2seq_batch(self):
-        batch = self.tokenizer.prepare_seq2seq_batch(
-            self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens), return_tensors="pt"
+    def test_tokenizer_prepare_batch(self):
+        batch = self.tokenizer(
+            self.src_text, padding=True, truncation=True, max_length=len(self.expected_src_tokens), return_tensors="pt"
         )
-        batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
+        with self.tokenizer.as_target_tokenizer():
+            targets = self.tokenizer(
+                self.tgt_text,
+                padding=True,
+                truncation=True,
+                max_length=len(self.expected_src_tokens),
+                return_tensors="pt",
+            )
+        labels = targets["input_ids"]
+        batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
 
         self.assertIsInstance(batch, BatchEncoding)
         self.assertEqual((2, 14), batch.input_ids.shape)
@@ -185,16 +188,11 @@ class MBartOneToManyIntegrationTest(unittest.TestCase):
         self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
 
     def test_seq2seq_max_target_length(self):
-        batch = self.tokenizer.prepare_seq2seq_batch(
-            self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10, return_tensors="pt"
-        )
-        batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
+        batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
+        with self.tokenizer.as_target_tokenizer():
+            targets = self.tokenizer(self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt")
+        labels = targets["input_ids"]
+        batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
 
         self.assertEqual(batch.input_ids.shape[1], 3)
         self.assertEqual(batch.decoder_input_ids.shape[1], 10)
-        # max_target_length will default to max_length if not specified
-        batch = self.tokenizer.prepare_seq2seq_batch(
-            self.src_text, tgt_texts=self.tgt_text, max_length=3, return_tensors="pt"
-        )
-        batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
-        self.assertEqual(batch.input_ids.shape[1], 3)
-        self.assertEqual(batch.decoder_input_ids.shape[1], 3)
@@ -86,11 +86,13 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     def test_large_seq2seq_truncation(self):
         src_texts = ["This is going to be way too long." * 150, "short example"]
         tgt_texts = ["not super long but more than 5 tokens", "tiny"]
-        batch = self._large_tokenizer.prepare_seq2seq_batch(
-            src_texts, tgt_texts=tgt_texts, max_target_length=5, return_tensors="pt"
-        )
+        batch = self._large_tokenizer(src_texts, padding=True, truncation=True, return_tensors="pt")
+        with self._large_tokenizer.as_target_tokenizer():
+            targets = self._large_tokenizer(
+                tgt_texts, max_length=5, padding=True, truncation=True, return_tensors="pt"
+            )
 
         assert batch.input_ids.shape == (2, 1024)
         assert batch.attention_mask.shape == (2, 1024)
-        assert "labels" in batch  # because tgt_texts was specified
-        assert batch.labels.shape == (2, 5)
-        assert len(batch) == 3  # input_ids, attention_mask, labels. Other things make by BartModel
+        assert targets["input_ids"].shape == (2, 5)
+        assert len(batch) == 2  # input_ids, attention_mask.
@@ -152,20 +152,12 @@ class ProphetNetTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         self.assertListEqual(tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
 
     @require_torch
-    def test_prepare_seq2seq_batch(self):
+    def test_prepare_batch(self):
         tokenizer = self.tokenizer_class.from_pretrained("microsoft/prophetnet-large-uncased")
         src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
-        tgt_text = [
-            "Summary of the text.",
-            "Another summary.",
-        ]
         expected_src_tokens = [1037, 2146, 20423, 2005, 7680, 7849, 3989, 1012, 102]
-        batch = tokenizer.prepare_seq2seq_batch(
-            src_text,
-            tgt_texts=tgt_text,
-            return_tensors="pt",
-        )
+        batch = tokenizer(src_text, padding=True, return_tensors="pt")
         self.assertIsInstance(batch, BatchEncoding)
 
         result = list(batch.input_ids.numpy()[0])
         self.assertListEqual(expected_src_tokens, result)
@@ -151,19 +151,11 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         batch_without_eos_added = tokenizer(["hi", "I went to the gym", ""])
         self.assertListEqual(batch_with_eos_added["input_ids"], batch_without_eos_added["input_ids"])
 
-    def test_prepare_seq2seq_batch(self):
+    def test_prepare_batch(self):
         tokenizer = self.t5_base_tokenizer
         src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
-        tgt_text = [
-            "Summary of the text.",
-            "Another summary.",
-        ]
         expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, tokenizer.eos_token_id]
-        batch = tokenizer.prepare_seq2seq_batch(
-            src_text,
-            tgt_texts=tgt_text,
-            return_tensors=FRAMEWORK,
-        )
+        batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK)
         self.assertIsInstance(batch, BatchEncoding)
         result = list(batch.input_ids.numpy()[0])
         self.assertListEqual(expected_src_tokens, result)
@@ -174,36 +166,30 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     def test_empty_target_text(self):
         tokenizer = self.t5_base_tokenizer
         src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
-        batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors=FRAMEWORK)
+        batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK)
         # check if input_ids are returned and no decoder_input_ids
         self.assertIn("input_ids", batch)
         self.assertIn("attention_mask", batch)
         self.assertNotIn("decoder_input_ids", batch)
         self.assertNotIn("decoder_attention_mask", batch)
 
-    def test_max_target_length(self):
+    def test_max_length(self):
         tokenizer = self.t5_base_tokenizer
-        src_text = ["A short paragraph for summarization.", "Another short paragraph for summarization."]
         tgt_text = [
             "Summary of the text.",
             "Another summary.",
         ]
-        batch = tokenizer.prepare_seq2seq_batch(
-            src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors=FRAMEWORK
-        )
-        self.assertEqual(32, batch["labels"].shape[1])
-
-        # test None max_target_length
-        batch = tokenizer.prepare_seq2seq_batch(
-            src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors=FRAMEWORK
-        )
-        self.assertEqual(32, batch["labels"].shape[1])
+        with tokenizer.as_target_tokenizer():
+            targets = tokenizer(
+                tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
+            )
+        self.assertEqual(32, targets["input_ids"].shape[1])
 
     def test_outputs_not_longer_than_maxlen(self):
         tokenizer = self.t5_base_tokenizer
 
-        batch = tokenizer.prepare_seq2seq_batch(
-            ["I am a small frog" * 1000, "I am a small frog"], return_tensors=FRAMEWORK
+        batch = tokenizer(
+            ["I am a small frog" * 1000, "I am a small frog"], padding=True, truncation=True, return_tensors=FRAMEWORK
         )
         self.assertIsInstance(batch, BatchEncoding)
         self.assertEqual(batch.input_ids.shape, (2, 512))
@@ -215,13 +201,12 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, 1]
         expected_tgt_tokens = [20698, 13, 8, 1499, 5, 1]
 
-        batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors=FRAMEWORK)
-        src_ids = list(batch.input_ids.numpy()[0])
-        tgt_ids = list(batch.labels.numpy()[0])
+        batch = tokenizer(src_text)
+        with tokenizer.as_target_tokenizer():
+            targets = tokenizer(tgt_text)
 
-        self.assertEqual(expected_src_tokens, src_ids)
-        self.assertEqual(expected_tgt_tokens, tgt_ids)
+        self.assertEqual(expected_src_tokens, batch["input_ids"][0])
+        self.assertEqual(expected_tgt_tokens, targets["input_ids"][0])
 
     def test_token_type_ids(self):
         src_text_1 = ["A first paragraph for summarization."]