Unverified commit 986526a0, authored by Sylvain Gugger, committed by GitHub

Replace `as_target` context managers by direct calls (#18325)



* Preliminary work on tokenizers

* Quality + fix tests

* Treat processors

* Fix pad

* Remove all uses of the `as_target` context managers in tests, docs and examples

* Replace all as_target_tokenizer

* Fix tests

* Fix quality

* Update examples/flax/image-captioning/run_image_captioning_flax.py
Co-authored-by: amyeroberts <amy@huggingface.co>

* Style
Co-authored-by: amyeroberts <amy@huggingface.co>
parent a64bcb56
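All of the hunks below apply the same mechanical migration. As a reference point, here is a minimal before/after sketch of the two calling conventions (the `t5-small` checkpoint is only an illustrative choice; any seq2seq tokenizer behaves the same way):

```python
from transformers import AutoTokenizer

# illustrative checkpoint; any seq2seq tokenizer works the same way
tokenizer = AutoTokenizer.from_pretrained("t5-small")

src_text = ["Studies have shown that owning a dog is good for you."]
tgt_text = ["Summary of the text."]

# Old pattern (the one this PR replaces): enter a context manager that
# switches the tokenizer into target mode, then tokenize the labels separately.
with tokenizer.as_target_tokenizer():
    labels = tokenizer(tgt_text, return_tensors="pt").input_ids

# New pattern: pass the targets directly via `text_target`; their encoding
# comes back under the "labels" key of the same batch.
batch = tokenizer(src_text, text_target=tgt_text, return_tensors="pt")
print(batch["labels"])
```

Passing `text` and `text_target` in one call yields a single `BatchEncoding` holding both `input_ids` and `labels`, which is what lets most hunks below collapse two tokenizer calls into one.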
@@ -112,14 +112,13 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
         self.assertNotIn("decoder_attention_mask", batch)

     @require_torch
-    def test_as_target_tokenizer_target_length(self):
+    def test_tokenizer_as_target_length(self):
         tgt_text = [
             "Summary of the text.",
             "Another summary.",
         ]
         for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
-            with tokenizer.as_target_tokenizer():
-                targets = tokenizer(tgt_text, max_length=32, padding="max_length", return_tensors="pt")
+            targets = tokenizer(text_target=tgt_text, max_length=32, padding="max_length", return_tensors="pt")
             self.assertEqual(32, targets["input_ids"].shape[1])

     @require_torch
@@ -140,8 +139,7 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
         ]
         for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
             inputs = tokenizer(src_text, return_tensors="pt")
-            with tokenizer.as_target_tokenizer():
-                targets = tokenizer(tgt_text, return_tensors="pt")
+            targets = tokenizer(text_target=tgt_text, return_tensors="pt")
             input_ids = inputs["input_ids"]
             labels = targets["input_ids"]
             self.assertTrue((input_ids[:, 0] == tokenizer.bos_token_id).all().item())
@@ -152,10 +152,9 @@ class ByT5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
             "Summary of the text.",
             "Another summary.",
         ]
-        with tokenizer.as_target_tokenizer():
-            targets = tokenizer(
-                tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
-            )
+        targets = tokenizer(
+            text_target=tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
+        )
         self.assertEqual(32, targets["input_ids"].shape[1])

     def test_eos_in_input(self):
@@ -167,12 +166,10 @@ class ByT5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         expected_tgt_tokens = [86, 120, 112, 112, 100, 117, 124, 35, 114, 105, 35, 119, 107, 104, 35, 119, 104, 123, 119, 49, 35, 1]
         # fmt: on

-        batch = tokenizer(src_text)
-        with tokenizer.as_target_tokenizer():
-            targets = tokenizer(tgt_text)
+        batch = tokenizer(src_text, text_target=tgt_text)

         self.assertEqual(expected_src_tokens, batch["input_ids"][0])
-        self.assertEqual(expected_tgt_tokens, targets["input_ids"][0])
+        self.assertEqual(expected_tgt_tokens, batch["labels"][0])

     # cannot use default save_and_load_tokenzier test method because tokenzier has no vocab
     def test_save_and_load_tokenizer(self):
@@ -80,8 +80,9 @@ class CanineTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
             "What's the weater?",
             "It's about 25 degrees.",
         ]
-        with tokenizer.as_target_tokenizer():
-            targets = tokenizer(tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors="pt")
+        targets = tokenizer(
+            text_target=tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors="pt"
+        )
         self.assertEqual(32, targets["input_ids"].shape[1])

     # cannot use default save_and_load_tokenzier test method because tokenzier has no vocab
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 from transformers import (
     DPRContextEncoderTokenizer,
     DPRContextEncoderTokenizerFast,
@@ -187,9 +187,7 @@ class M2M100TokenizerIntegrationTest(unittest.TestCase):
         self.tokenizer.src_lang = "en"
         self.tokenizer.tgt_lang = "fr"

-        batch = self.tokenizer(self.src_text, padding=True, return_tensors="pt")
-        with self.tokenizer.as_target_tokenizer():
-            batch["labels"] = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt").input_ids
+        batch = self.tokenizer(self.src_text, text_target=self.tgt_text, padding=True, return_tensors="pt")

         batch["decoder_input_ids"] = shift_tokens_right(
             batch["labels"], self.tokenizer.pad_token_id, self.tokenizer.eos_token_id
@@ -217,17 +215,19 @@ class M2M100TokenizerIntegrationTest(unittest.TestCase):
         self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])

     @require_torch
-    def test_as_target_tokenizer(self):
+    def test_tokenizer_target_mode(self):
         self.tokenizer.tgt_lang = "mr"
-        with self.tokenizer.as_target_tokenizer():
-            self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("mr")])
-            self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
+        self.tokenizer._switch_to_target_mode()
+        self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("mr")])
+        self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
+        self.tokenizer._switch_to_input_mode()
         self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id(self.tokenizer.src_lang)])

         self.tokenizer.tgt_lang = "zh"
-        with self.tokenizer.as_target_tokenizer():
-            self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("zh")])
-            self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
+        self.tokenizer._switch_to_target_mode()
+        self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("zh")])
+        self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
+        self.tokenizer._switch_to_input_mode()
         self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id(self.tokenizer.src_lang)])

     @require_torch
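In normal use nobody calls these private helpers; passing `text_target` triggers the same switch to target mode and restores input mode afterwards. A small sketch (the checkpoint and the Marathi sample are illustrative; `get_lang_id` is part of the real M2M100 API):

```python
from transformers import M2M100Tokenizer

# illustrative checkpoint for the M2M100 family
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
tokenizer.src_lang, tokenizer.tgt_lang = "en", "mr"

# `text_target` switches to target mode for the labels, then switches back,
# mirroring the _switch_to_target_mode()/_switch_to_input_mode() calls above
batch = tokenizer("Hello", text_target="नमस्कार", return_tensors="pt")

# in target mode the prefix token is the *target* language code
assert batch["labels"][0, 0] == tokenizer.get_lang_id("mr")
```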
@@ -438,10 +438,7 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
         src, tgt = ["I am a small frog"], ["Ich bin ein kleiner Frosch."]
         expected_ids = [38, 121, 14, 697, 38848, 0]

-        model_inputs = self.tokenizer(src, return_tensors="pt").to(torch_device)
-        with self.tokenizer.as_target_tokenizer():
-            targets = self.tokenizer(tgt, return_tensors="pt")
-        model_inputs["labels"] = targets["input_ids"].to(torch_device)
+        model_inputs = self.tokenizer(src, text_target=tgt, return_tensors="pt").to(torch_device)

         self.assertListEqual(expected_ids, model_inputs.input_ids[0].tolist())
@@ -145,9 +145,8 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         src_ids = tokenizer(source_text).input_ids
         self.assertListEqual(src_ids, expected_src_ids)

-        with tokenizer.as_target_tokenizer():
-            target_ids = tokenizer(target_text).input_ids
-            self.assertListEqual(target_ids, expected_target_ids)
+        target_ids = tokenizer(text_target=target_text).input_ids
+        self.assertListEqual(target_ids, expected_target_ids)

         decoded = tokenizer.decode(target_ids, skip_special_tokens=True)
         self.assertEqual(decoded, target_text)
@@ -265,33 +265,27 @@ class MBartEnroIntegrationTest(unittest.TestCase):
     @require_torch
     def test_batch_fairseq_parity(self):
-        batch = self.tokenizer(self.src_text, padding=True)
-        with self.tokenizer.as_target_tokenizer():
-            targets = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt")
-        labels = targets["input_ids"]
-        batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id).tolist()
+        batch = self.tokenizer(self.src_text, text_target=self.tgt_text, padding=True, return_tensors="pt")
+        batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], self.tokenizer.pad_token_id)

         # fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
-        assert batch.input_ids[1][-2:] == [2, EN_CODE]
-        assert batch.decoder_input_ids[1][0] == RO_CODE
+        assert batch.input_ids[1][-2:].tolist() == [2, EN_CODE]
+        assert batch.decoder_input_ids[1][0].tolist() == RO_CODE
         assert batch.decoder_input_ids[1][-1] == 2
-        assert labels[1][-2:].tolist() == [2, RO_CODE]
+        assert batch.labels[1][-2:].tolist() == [2, RO_CODE]

     @require_torch
     def test_enro_tokenizer_prepare_batch(self):
         batch = self.tokenizer(
-            self.src_text, padding=True, truncation=True, max_length=len(self.expected_src_tokens), return_tensors="pt"
+            self.src_text,
+            text_target=self.tgt_text,
+            padding=True,
+            truncation=True,
+            max_length=len(self.expected_src_tokens),
+            return_tensors="pt",
         )
-        with self.tokenizer.as_target_tokenizer():
-            targets = self.tokenizer(
-                self.tgt_text,
-                padding=True,
-                truncation=True,
-                max_length=len(self.expected_src_tokens),
-                return_tensors="pt",
-            )
-        labels = targets["input_ids"]
-        batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
+        batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], self.tokenizer.pad_token_id)

         self.assertIsInstance(batch, BatchEncoding)
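With `labels` already in the batch, preparing decoder inputs stays a one-liner, as the rewritten tests show. A sketch, assuming the `facebook/mbart-large-en-ro` checkpoint that this test class exercises:

```python
from transformers import MBartTokenizer
from transformers.models.mbart.modeling_mbart import shift_tokens_right

tokenizer = MBartTokenizer.from_pretrained(
    "facebook/mbart-large-en-ro", src_lang="en_XX", tgt_lang="ro_RO"
)
batch = tokenizer(
    ["UN Chief Says There Is No Military Solution in Syria"],
    text_target=["Şeful ONU declară că nu există o soluţie militară în Siria"],
    padding=True,
    return_tensors="pt",
)
# MBart's shift_tokens_right rotates the last non-pad token (the language
# code) to position 0, producing decoder_input_ids from the labels
batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], tokenizer.pad_token_id)
```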
@@ -306,8 +300,9 @@ class MBartEnroIntegrationTest(unittest.TestCase):
     def test_seq2seq_max_length(self):
         batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
-        with self.tokenizer.as_target_tokenizer():
-            targets = self.tokenizer(self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt")
+        targets = self.tokenizer(
+            text_target=self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt"
+        )
         labels = targets["input_ids"]
         batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
@@ -256,35 +256,27 @@ class MBart50OneToManyIntegrationTest(unittest.TestCase):
     @require_torch
     def test_batch_fairseq_parity(self):
-        batch = self.tokenizer(self.src_text, padding=True)
-        with self.tokenizer.as_target_tokenizer():
-            targets = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt")
-        labels = targets["input_ids"]
-        batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id).tolist()
-        labels = labels.tolist()
+        batch = self.tokenizer(self.src_text, text_target=self.tgt_text, padding=True, return_tensors="pt")
+        batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], self.tokenizer.pad_token_id)

         # fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
         assert batch.input_ids[1][0] == EN_CODE
         assert batch.input_ids[1][-1] == 2
-        assert labels[1][0] == RO_CODE
-        assert labels[1][-1] == 2
-        assert batch.decoder_input_ids[1][:2] == [2, RO_CODE]
+        assert batch.labels[1][0] == RO_CODE
+        assert batch.labels[1][-1] == 2
+        assert batch.decoder_input_ids[1][:2].tolist() == [2, RO_CODE]

     @require_torch
     def test_tokenizer_prepare_batch(self):
         batch = self.tokenizer(
-            self.src_text, padding=True, truncation=True, max_length=len(self.expected_src_tokens), return_tensors="pt"
+            self.src_text,
+            text_target=self.tgt_text,
+            padding=True,
+            truncation=True,
+            max_length=len(self.expected_src_tokens),
+            return_tensors="pt",
         )
-        with self.tokenizer.as_target_tokenizer():
-            targets = self.tokenizer(
-                self.tgt_text,
-                padding=True,
-                truncation=True,
-                max_length=len(self.expected_src_tokens),
-                return_tensors="pt",
-            )
-        labels = targets["input_ids"]
-        batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
+        batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], self.tokenizer.pad_token_id)

         self.assertIsInstance(batch, BatchEncoding)
@@ -299,8 +291,9 @@ class MBart50OneToManyIntegrationTest(unittest.TestCase):
     def test_seq2seq_max_target_length(self):
         batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
-        with self.tokenizer.as_target_tokenizer():
-            targets = self.tokenizer(self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt")
+        targets = self.tokenizer(
+            text_target=self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt"
+        )
         labels = targets["input_ids"]
         batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
@@ -125,8 +125,7 @@ class MCTCTProcessorTest(unittest.TestCase):
         input_str = "This is a test string"

-        with processor.as_target_processor():
-            encoded_processor = processor(input_str)
+        encoded_processor = processor(text=input_str)

         encoded_tok = tokenizer(input_str)
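The processor tests (M-CTC-T here, Speech2Text and Wav2Vec2 further down) all migrate the same way: the `as_target_processor` context becomes an explicit `text=` keyword that routes the call to the underlying tokenizer. A minimal sketch with Wav2Vec2 (the checkpoint name is illustrative):

```python
from transformers import Wav2Vec2Processor

# illustrative checkpoint; any Wav2Vec2 processor behaves the same way
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
input_str = "This is a test string"

# Old pattern (replaced by this PR): temporarily make the processor's
# __call__ delegate to the tokenizer
with processor.as_target_processor():
    encoded = processor(input_str)

# New pattern: route to the tokenizer explicitly with the `text` keyword
encoded = processor(text=input_str)
```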
@@ -112,14 +112,13 @@ class TestTokenizationMvp(TokenizerTesterMixin, unittest.TestCase):
         self.assertNotIn("decoder_attention_mask", batch)

     @require_torch
-    def test_as_target_tokenizer_target_length(self):
+    def test_tokenizer_as_target_length(self):
         tgt_text = [
             "Summary of the text.",
             "Another summary.",
         ]
         for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
-            with tokenizer.as_target_tokenizer():
-                targets = tokenizer(tgt_text, max_length=32, padding="max_length", return_tensors="pt")
+            targets = tokenizer(text_target=tgt_text, max_length=32, padding="max_length", return_tensors="pt")
             self.assertEqual(32, targets["input_ids"].shape[1])

     @require_torch
@@ -139,11 +138,9 @@ class TestTokenizationMvp(TokenizerTesterMixin, unittest.TestCase):
             "Summary of the text.",
         ]
         for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
-            inputs = tokenizer(src_text, return_tensors="pt")
-            with tokenizer.as_target_tokenizer():
-                targets = tokenizer(tgt_text, return_tensors="pt")
+            inputs = tokenizer(src_text, text_target=tgt_text, return_tensors="pt")
             input_ids = inputs["input_ids"]
-            labels = targets["input_ids"]
+            labels = inputs["labels"]
             self.assertTrue((input_ids[:, 0] == tokenizer.bos_token_id).all().item())
             self.assertTrue((labels[:, 0] == tokenizer.bos_token_id).all().item())
             self.assertTrue((input_ids[:, -1] == tokenizer.eos_token_id).all().item())
@@ -373,19 +373,15 @@ class NllbDistilledIntegrationTest(unittest.TestCase):
     @require_torch
     def test_enro_tokenizer_prepare_batch(self):
         batch = self.tokenizer(
-            self.src_text, padding=True, truncation=True, max_length=len(self.expected_src_tokens), return_tensors="pt"
+            self.src_text,
+            text_target=self.tgt_text,
+            padding=True,
+            truncation=True,
+            max_length=len(self.expected_src_tokens),
+            return_tensors="pt",
         )
-        with self.tokenizer.as_target_tokenizer():
-            targets = self.tokenizer(
-                self.tgt_text,
-                padding=True,
-                truncation=True,
-                max_length=len(self.expected_src_tokens),
-                return_tensors="pt",
-            )
-        labels = targets["input_ids"]
         batch["decoder_input_ids"] = shift_tokens_right(
-            labels, self.tokenizer.pad_token_id, self.tokenizer.lang_code_to_id["ron_Latn"]
+            batch["labels"], self.tokenizer.pad_token_id, self.tokenizer.lang_code_to_id["ron_Latn"]
         )

         self.assertIsInstance(batch, BatchEncoding)
@@ -401,8 +397,9 @@ class NllbDistilledIntegrationTest(unittest.TestCase):
     def test_seq2seq_max_length(self):
         batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
-        with self.tokenizer.as_target_tokenizer():
-            targets = self.tokenizer(self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt")
+        targets = self.tokenizer(
+            text_target=self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt"
+        )
         labels = targets["input_ids"]
         batch["decoder_input_ids"] = shift_tokens_right(
             labels,
@@ -109,10 +109,9 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         src_texts = ["This is going to be way too long." * 150, "short example"]
         tgt_texts = ["not super long but more than 5 tokens", "tiny"]
         batch = self._large_tokenizer(src_texts, padding=True, truncation=True, return_tensors="pt")
-        with self._large_tokenizer.as_target_tokenizer():
-            targets = self._large_tokenizer(
-                tgt_texts, max_length=5, padding=True, truncation=True, return_tensors="pt"
-            )
+        targets = self._large_tokenizer(
+            text_target=tgt_texts, max_length=5, padding=True, truncation=True, return_tensors="pt"
+        )

         assert batch.input_ids.shape == (2, 1024)
         assert batch.attention_mask.shape == (2, 1024)
@@ -174,10 +173,9 @@ class BigBirdPegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         src_texts = ["This is going to be way too long." * 1000, "short example"]
         tgt_texts = ["not super long but more than 5 tokens", "tiny"]
         batch = self._large_tokenizer(src_texts, padding=True, truncation=True, return_tensors="pt")
-        with self._large_tokenizer.as_target_tokenizer():
-            targets = self._large_tokenizer(
-                tgt_texts, max_length=5, padding=True, truncation=True, return_tensors="pt"
-            )
+        targets = self._large_tokenizer(
+            text_target=tgt_texts, max_length=5, padding=True, truncation=True, return_tensors="pt"
+        )

         assert batch.input_ids.shape == (2, 4096)
         assert batch.attention_mask.shape == (2, 4096)
@@ -146,10 +146,9 @@ class PerceiverTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
             "Summary of the text.",
             "Another summary.",
         ]
-        with tokenizer.as_target_tokenizer():
-            targets = tokenizer(
-                tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
-            )
+        targets = tokenizer(
+            text_target=tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
+        )
         self.assertEqual(32, targets["input_ids"].shape[1])

     # cannot use default save_and_load_tokenzier test method because tokenzier has no vocab
@@ -299,33 +299,26 @@ class PLBartPythonEnIntegrationTest(unittest.TestCase):
     @require_torch
     def test_batch_fairseq_parity(self):
-        batch = self.tokenizer(self.src_text, padding=True)
-        with self.tokenizer.as_target_tokenizer():
-            targets = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt")
-        labels = targets["input_ids"]
-        batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id).tolist()
+        batch = self.tokenizer(self.src_text, text_target=self.tgt_text, padding=True, return_tensors="pt")
+        batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], self.tokenizer.pad_token_id)

         # fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
-        self.assertEqual(batch.input_ids[1][-2:], [2, PYTHON_CODE])
+        self.assertEqual(batch.input_ids[1][-2:].tolist(), [2, PYTHON_CODE])
         self.assertEqual(batch.decoder_input_ids[1][0], EN_CODE)
         self.assertEqual(batch.decoder_input_ids[1][-1], 2)
-        self.assertEqual(labels[1][-2:].tolist(), [2, EN_CODE])
+        self.assertEqual(batch.labels[1][-2:].tolist(), [2, EN_CODE])

     @require_torch
     def test_python_en_tokenizer_prepare_batch(self):
         batch = self.tokenizer(
-            self.src_text, padding=True, truncation=True, max_length=len(self.expected_src_tokens), return_tensors="pt"
+            self.src_text,
+            text_target=self.tgt_text,
+            padding=True,
+            truncation=True,
+            max_length=len(self.expected_src_tokens),
+            return_tensors="pt",
         )
-        with self.tokenizer.as_target_tokenizer():
-            targets = self.tokenizer(
-                self.tgt_text,
-                padding=True,
-                truncation=True,
-                max_length=len(self.expected_src_tokens),
-                return_tensors="pt",
-            )
-        labels = targets["input_ids"]
-        batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
+        batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], self.tokenizer.pad_token_id)

         self.assertIsInstance(batch, BatchEncoding)
@@ -340,8 +333,9 @@ class PLBartPythonEnIntegrationTest(unittest.TestCase):
     def test_seq2seq_max_length(self):
         batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
-        with self.tokenizer.as_target_tokenizer():
-            targets = self.tokenizer(self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt")
+        targets = self.tokenizer(
+            text_target=self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt"
+        )
         labels = targets["input_ids"]
         batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
@@ -125,8 +125,7 @@ class Speech2TextProcessorTest(unittest.TestCase):
         input_str = "This is a test string"

-        with processor.as_target_processor():
-            encoded_processor = processor(input_str)
+        encoded_processor = processor(text=input_str)

         encoded_tok = tokenizer(input_str)
@@ -210,10 +210,9 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
             "Summary of the text.",
             "Another summary.",
         ]
-        with tokenizer.as_target_tokenizer():
-            targets = tokenizer(
-                tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
-            )
+        targets = tokenizer(
+            text_target=tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
+        )
         self.assertEqual(32, targets["input_ids"].shape[1])

     def test_outputs_not_longer_than_maxlen(self):
@@ -235,12 +234,10 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, 1]
         expected_tgt_tokens = [20698, 13, 8, 1499, 5, 1]

-        batch = tokenizer(src_text)
-        with tokenizer.as_target_tokenizer():
-            targets = tokenizer(tgt_text)
+        batch = tokenizer(src_text, text_target=tgt_text)

         self.assertEqual(expected_src_tokens, batch["input_ids"][0])
-        self.assertEqual(expected_tgt_tokens, targets["input_ids"][0])
+        self.assertEqual(expected_tgt_tokens, batch["labels"][0])

     def test_token_type_ids(self):
         src_text_1 = ["A first paragraph for summarization."]
@@ -859,9 +859,8 @@ class TapexTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         tokenizer = TapexTokenizer.from_pretrained("microsoft/tapex-base")
         answer_text = "tapex is a good model!"
         expected_src_tokens = [0, 90, 5776, 1178, 16, 10, 205, 1421, 328, 2]
-        with tokenizer.as_target_tokenizer():
-            answer_encoding = tokenizer(answer=answer_text)
-            self.assertListEqual(answer_encoding.input_ids, expected_src_tokens)
+        answer_encoding = tokenizer(answer=answer_text)
+        self.assertListEqual(answer_encoding.input_ids, expected_src_tokens)

     @slow
     def test_tokenizer_lower_case(self):
@@ -870,23 +869,21 @@ class TapexTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         answer_text = "Beijing, London, Paris"
         answer_text_lower = "beijing, london, paris"

-        with cased_tokenizer.as_target_tokenizer():
-            with uncased_tokenizer.as_target_tokenizer():
-                self.assertNotEqual(
-                    cased_tokenizer(answer=answer_text).input_ids, uncased_tokenizer(answer=answer_text).input_ids
-                )
-                self.assertEqual(
-                    cased_tokenizer(answer=answer_text_lower).input_ids,
-                    uncased_tokenizer(answer=answer_text).input_ids,
-                )
-                # batched encoding assert
-                self.assertNotEqual(
-                    cased_tokenizer(answer=[answer_text]).input_ids, uncased_tokenizer(answer=[answer_text]).input_ids
-                )
-                self.assertEqual(
-                    cased_tokenizer(answer=[answer_text_lower]).input_ids,
-                    uncased_tokenizer(answer=[answer_text]).input_ids,
-                )
+        self.assertNotEqual(
+            cased_tokenizer(answer=answer_text).input_ids, uncased_tokenizer(answer=answer_text).input_ids
+        )
+        self.assertEqual(
+            cased_tokenizer(answer=answer_text_lower).input_ids,
+            uncased_tokenizer(answer=answer_text).input_ids,
+        )
+        # batched encoding assert
+        self.assertNotEqual(
+            cased_tokenizer(answer=[answer_text]).input_ids, uncased_tokenizer(answer=[answer_text]).input_ids
+        )
+        self.assertEqual(
+            cased_tokenizer(answer=[answer_text_lower]).input_ids,
+            uncased_tokenizer(answer=[answer_text]).input_ids,
+        )

         # test input encoding lowercase
         question = "Greece held its last Summer Olympics in 2004"
         table_dict = {
@@ -118,8 +118,7 @@ class Wav2Vec2ProcessorTest(unittest.TestCase):
         input_str = "This is a test string"

-        with processor.as_target_processor():
-            encoded_processor = processor(input_str)
+        encoded_processor = processor(text=input_str)

         encoded_tok = tokenizer(input_str)
@@ -164,8 +164,7 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
         input_str = "This is a test string"

-        with processor.as_target_processor():
-            encoded_processor = processor(input_str)
+        encoded_processor = processor(text=input_str)

         encoded_tok = tokenizer(input_str)