Unverified Commit 6326aa4b authored by Apoorv Garg, committed by GitHub

Correct order of overflowing tokens for LayoutLmV2 tokenizer (#13495)



* correct order of overflowing tokens for LayoutLmV2 tokenizer

* test to check order of overflowing_tokens for a seq of input_ids

* fix up quality

* added suggested changes

* check that tests the bbox sequence

* pair_input test added

* pass quality test

* check bbox sequence added

* unittest method

* comments added

* add overflowing bbox test

* improved "seq_1"
Co-authored-by: SaulLu <55560583+SaulLu@users.noreply.github.com>

* improve code quality
Co-authored-by: SaulLu <lucilesaul.com@gmail.com>
Co-authored-by: SaulLu <55560583+SaulLu@users.noreply.github.com>
parent 95b3ec3b
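For context, a minimal sketch of the user-facing behavior this commit pins down. The checkpoint name and the toy inputs are illustrative, not part of the diff:

```python
from transformers import LayoutLMv2Tokenizer

# Illustrative checkpoint; any LayoutLMv2 vocabulary behaves the same way.
tokenizer = LayoutLMv2Tokenizer.from_pretrained("microsoft/layoutlmv2-base-uncased")

words = ["hello", "world", "how", "are", "you", "today"]
boxes = [[i, i, i, i] for i in range(len(words))]

encoding = tokenizer(
    words,
    boxes=boxes,
    add_special_tokens=False,
    max_length=4,  # forces 2 tokens to overflow
    stride=1,  # 1 token of overlap between windows
    truncation=True,
    return_overflowing_tokens=True,
)

# After this fix, the slow tokenizer returns the overflowing ids in their
# original left-to-right order (the last 2 + stride ids of the sequence),
# with the matching boxes under "overflowing_token_boxes".
print(encoding["overflowing_tokens"])
print(encoding["overflowing_token_boxes"])
```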
@@ -650,7 +650,7 @@ class LayoutLMv2Tokenizer(PreTrainedTokenizer):
"""
Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It
adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
-manages a moving window (with user defined stride) for overflowing tokens
+manages a moving window (with user defined stride) for overflowing tokens.
Args:
batch_ids_pairs: list of tokenized input ids or input ids pairs
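The moving window described in this docstring amounts, in the slow path, to keeping the last `num_tokens_to_remove + stride` ids as the overflow slice. A standalone sketch (the helper name is made up for illustration):

```python
def split_with_stride(ids, num_tokens_to_remove, stride):
    # The kept part loses `num_tokens_to_remove` ids from the right; the
    # overflow keeps those ids plus `stride` ids of overlap, in original order.
    window_len = min(len(ids), stride + num_tokens_to_remove)
    return ids[:-num_tokens_to_remove], ids[-window_len:]

kept, overflow = split_with_stride(list(range(10)), num_tokens_to_remove=2, stride=1)
assert kept == [0, 1, 2, 3, 4, 5, 6, 7]
assert overflow == [7, 8, 9]
```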
@@ -893,7 +893,9 @@ class LayoutLMv2Tokenizer(PreTrainedTokenizer):
"""
Prepares a sequence or a pair of sequences so that it can be used by the model. It adds special tokens,
truncates sequences if overflowing while taking into account the special tokens and manages a moving window
-(with user defined stride) for overflowing tokens.
+(with user defined stride) for overflowing tokens. Please Note, for `text_pair` different than `None` and
+`truncation_strategy = longest_first` or `True`, it is not possible to return overflowing tokens. Such a
+combination of arguments will raise an error.
Word-level :obj:`boxes` are turned into token-level :obj:`bbox`. If provided, word-level :obj:`word_labels` are
turned into token-level :obj:`labels`. The word label is used for the first token of the word, while remaining
@@ -963,6 +965,17 @@ class LayoutLMv2Tokenizer(PreTrainedTokenizer):
ids = self.convert_tokens_to_ids(tokens)
pair_ids = self.convert_tokens_to_ids(pair_tokens) if pair_tokens else None
+if (
+return_overflowing_tokens
+and truncation_strategy == TruncationStrategy.LONGEST_FIRST
+and pair_ids is not None
+):
+raise ValueError(
+"Not possible to return overflowing tokens for pair of sequences with the "
+"`longest_first`. Please select another truncation strategy than `longest_first`, "
+"for instance `only_second` or `only_first`."
+)
# Compute the total size of the returned encodings
pair = bool(pair_ids is not None)
len_ids = len(ids)
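At the call site, the new guard surfaces like this (sketch, reusing a `tokenizer` built as in the first example):

```python
question = "what is written here?"
words = ["the", "answer", "is", "here"]
boxes = [[i, i, i, i] for i in range(len(words))]

try:
    tokenizer(
        question,
        words,
        boxes=boxes,
        max_length=6,
        truncation="longest_first",  # same with truncation=True
        return_overflowing_tokens=True,
    )
except ValueError as err:
    print(err)  # suggests `only_first` or `only_second` instead
```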
@@ -1114,7 +1127,8 @@ class LayoutLMv2Tokenizer(PreTrainedTokenizer):
Returns:
:obj:`Tuple[List[int], List[int], List[int]]`: The truncated ``ids``, the truncated ``pair_ids`` and the
-list of overflowing tokens.
+list of overflowing tokens. Note: The `longest_first` strategy returns empty list of overflowing tokens if
+a pair of sequences (or a batch of pairs) is provided.
"""
if num_tokens_to_remove <= 0:
return ids, token_boxes, pair_ids, pair_token_boxes, labels, [], [], []
@@ -1125,29 +1139,9 @@ class LayoutLMv2Tokenizer(PreTrainedTokenizer):
overflowing_tokens = []
overflowing_token_boxes = []
overflowing_labels = []
-if truncation_strategy == TruncationStrategy.LONGEST_FIRST:
-for _ in range(num_tokens_to_remove):
-if pair_ids is None or len(ids) > len(pair_ids):
-if not overflowing_tokens:
-window_len = min(len(ids), stride + 1)
-else:
-window_len = 1
-overflowing_tokens.extend(ids[-window_len:])
-overflowing_token_boxes.extend(token_boxes[-window_len:])
-overflowing_labels.extend(labels[-window_len:])
-ids = ids[:-1]
-token_boxes = token_boxes[:-1]
-labels = labels[:-1]
-else:
-if not overflowing_tokens:
-window_len = min(len(pair_ids), stride + 1)
-else:
-window_len = 1
-overflowing_tokens.extend(pair_ids[-window_len:])
-overflowing_token_boxes.extend(pair_token_boxes[-window_len:])
-pair_ids = pair_ids[:-1]
-pair_token_boxes = pair_token_boxes[:-1]
-elif truncation_strategy == TruncationStrategy.ONLY_FIRST:
+if truncation_strategy == TruncationStrategy.ONLY_FIRST or (
+truncation_strategy == TruncationStrategy.LONGEST_FIRST and pair_ids is None
+):
if len(ids) > num_tokens_to_remove:
window_len = min(len(ids), stride + num_tokens_to_remove)
overflowing_tokens = ids[-window_len:]
@@ -1157,12 +1151,31 @@ class LayoutLMv2Tokenizer(PreTrainedTokenizer):
token_boxes = token_boxes[:-num_tokens_to_remove]
labels = labels[:-num_tokens_to_remove]
else:
-logger.error(
+error_msg = (
f"We need to remove {num_tokens_to_remove} to truncate the input "
f"but the first sequence has a length {len(ids)}. "
-f"Please select another truncation strategy than {truncation_strategy}, "
-f"for instance 'longest_first' or 'only_second'."
)
+if truncation_strategy == TruncationStrategy.ONLY_FIRST:
+error_msg = (
+error_msg + "Please select another truncation strategy than "
+f"{truncation_strategy}, for instance 'longest_first' or 'only_second'."
+)
+logger.error(error_msg)
+elif truncation_strategy == TruncationStrategy.LONGEST_FIRST:
+logger.warning(
+f"Be aware, overflowing tokens are not returned for the setting you have chosen,"
+f" i.e. sequence pairs with the '{TruncationStrategy.LONGEST_FIRST.value}' "
+f"truncation strategy. So the returned list will always be empty even if some "
+f"tokens have been removed."
+)
+for _ in range(num_tokens_to_remove):
+if pair_ids is None or len(ids) > len(pair_ids):
+ids = ids[:-1]
+token_boxes = token_boxes[:-1]
+labels = labels[:-1]
+else:
+pair_ids = pair_ids[:-1]
+pair_token_boxes = pair_token_boxes[:-1]
elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None:
if len(pair_ids) > num_tokens_to_remove:
window_len = min(len(pair_ids), stride + num_tokens_to_remove)
......
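For reference, the alternating removal performed by the `longest_first` branch above — the reason no flat, correctly ordered overflow list can be emitted for pairs — can be sketched standalone:

```python
def longest_first_trim(ids, pair_ids, num_tokens_to_remove):
    # Illustrative only: pop one token at a time from whichever sequence is
    # currently longer, so removals interleave between the two sequences.
    ids, pair_ids = list(ids), list(pair_ids)
    for _ in range(num_tokens_to_remove):
        if len(ids) > len(pair_ids):
            ids.pop()
        else:
            pair_ids.pop()
    return ids, pair_ids

# Removals alternate between the sequences:
assert longest_first_trim([1, 2, 3, 4], [5, 6, 7], 3) == ([1, 2], [5, 6])
```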
@@ -3015,7 +3015,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
Returns:
:obj:`Tuple[List[int], List[int], List[int]]`: The truncated ``ids``, the truncated ``pair_ids`` and the
-list of overflowing tokens. Note: The `longest_first` strategy returns empty list of overflowing_tokens if
+list of overflowing tokens. Note: The `longest_first` strategy returns empty list of overflowing tokens if
a pair of sequences (or a batch of pairs) is provided.
"""
if num_tokens_to_remove <= 0:
......
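The base-class behavior documented above can be observed directly through the public `truncate_sequences` API (sketch; `tokenizer` is any slow tokenizer instance):

```python
ids, pair_ids, overflowing = tokenizer.truncate_sequences(
    [1, 2, 3, 4, 5],
    pair_ids=[6, 7, 8],
    num_tokens_to_remove=2,
    truncation_strategy="longest_first",
    stride=0,
)
assert overflowing == []  # empty for pairs under `longest_first`
```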
@@ -15,6 +15,7 @@
import inspect
import os
import re
import shutil
import tempfile
import unittest
@@ -1777,13 +1778,515 @@ class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def test_alignement_methods(self):
pass
-@unittest.skip("LayoutLMv2 tokenizer requires boxes besides sequences.")
def get_clean_sequence(self, tokenizer, with_prefix_space=False, max_length=20, min_length=5):
toks = [(i, tokenizer.decode([i], clean_up_tokenization_spaces=False)) for i in range(len(tokenizer))]
toks = list(filter(lambda t: re.match(r"^[ a-zA-Z]+$", t[1]), toks))
toks = list(
filter(
lambda t: [t[0]]
== tokenizer.encode(t[1].split(" "), boxes=len(t[1]) * [[1, 1, 1, 1]], add_special_tokens=False),
toks,
)
)
if max_length is not None and len(toks) > max_length:
toks = toks[:max_length]
if min_length is not None and len(toks) < min_length and len(toks) > 0:
while len(toks) < min_length:
toks = toks + toks
# toks_str = [t[1] for t in toks]
toks_ids = [t[0] for t in toks]
# Ensure consistency
output_txt = tokenizer.decode(toks_ids, clean_up_tokenization_spaces=False)
if " " not in output_txt and len(toks_ids) > 1:
output_txt = (
tokenizer.decode([toks_ids[0]], clean_up_tokenization_spaces=False)
+ " "
+ tokenizer.decode(toks_ids[1:], clean_up_tokenization_spaces=False)
)
if with_prefix_space:
output_txt = " " + output_txt
words = output_txt.split(" ")
boxes = [[i, i, i, i] for i in range(len(words))]
output_ids = tokenizer.encode(words, boxes=boxes, add_special_tokens=False)
return words, boxes, output_ids
# @unittest.skip("LayoutLMv2 tokenizer requires boxes besides sequences.")
def test_maximum_encoding_length_pair_input(self):
-pass
tokenizers = self.get_tokenizers(do_lower_case=False, model_max_length=100)
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
# Build a sequence from our model's vocabulary
stride = 2
seq_0, boxes_0, ids = self.get_clean_sequence(tokenizer, max_length=20)
question_0 = " ".join(map(str, seq_0))
if len(ids) <= 2 + stride:
seq_0 = (seq_0 + " ") * (2 + stride)
ids = None
seq0_tokens = tokenizer(seq_0, boxes=boxes_0, add_special_tokens=False)
self.assertGreater(len(seq0_tokens["input_ids"]), 2 + stride)
question_1 = "This is another sentence to be encoded."
seq_1 = ["what", "a", "weird", "test", "weirdly", "weird"]
boxes_1 = [[i, i, i, i] for i in range(len(seq_1))]
seq1_tokens = tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)
if abs(len(seq0_tokens["input_ids"]) - len(seq1_tokens["input_ids"])) <= 2:
seq1_tokens_input_ids = seq1_tokens["input_ids"] + seq1_tokens["input_ids"]
seq_1 = tokenizer.decode(seq1_tokens_input_ids, clean_up_tokenization_spaces=False)
seq_1 = seq_1.split(" ")
boxes_1 = [[i, i, i, i] for i in range(len(seq_1))]
seq1_tokens = tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)
self.assertGreater(len(seq1_tokens["input_ids"]), 2 + stride)
smallest = (
seq1_tokens["input_ids"]
if len(seq0_tokens["input_ids"]) > len(seq1_tokens["input_ids"])
else seq0_tokens["input_ids"]
)
-@unittest.skip("LayoutLMv2 tokenizer requires boxes besides sequences.")
# We are not using the special tokens - a bit too hard to test all the tokenizers with this
# TODO try this again later
sequence = tokenizer(
question_0, seq_1, boxes=boxes_1, add_special_tokens=False
) # , add_prefix_space=False)
# Test with max model input length
model_max_length = tokenizer.model_max_length
self.assertEqual(model_max_length, 100)
seq_2 = seq_0 * model_max_length
question_2 = " ".join(map(str, seq_2))
boxes_2 = boxes_0 * model_max_length
self.assertGreater(len(seq_2), model_max_length)
sequence1 = tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)
total_length1 = len(sequence1["input_ids"])
sequence2 = tokenizer(question_2, seq_1, boxes=boxes_1, add_special_tokens=False)
total_length2 = len(sequence2["input_ids"])
self.assertLess(total_length1, model_max_length, "Issue with the testing sequence, please update it.")
self.assertGreater(
total_length2, model_max_length, "Issue with the testing sequence, please update it."
)
# Simple
padding_strategies = (
[False, True, "longest"] if tokenizer.pad_token and tokenizer.pad_token_id >= 0 else [False]
)
for padding_state in padding_strategies:
with self.subTest(f"{tokenizer.__class__.__name__} Padding: {padding_state}"):
for truncation_state in [True, "longest_first", "only_first"]:
with self.subTest(f"{tokenizer.__class__.__name__} Truncation: {truncation_state}"):
output = tokenizer(
question_2,
seq_1,
boxes=boxes_1,
padding=padding_state,
truncation=truncation_state,
)
self.assertEqual(len(output["input_ids"]), model_max_length)
self.assertEqual(len(output["bbox"]), model_max_length)
output = tokenizer(
[question_2],
[seq_1],
boxes=[boxes_1],
padding=padding_state,
truncation=truncation_state,
)
self.assertEqual(len(output["input_ids"][0]), model_max_length)
self.assertEqual(len(output["bbox"][0]), model_max_length)
# Simple
output = tokenizer(
question_1, seq_2, boxes=boxes_2, padding=padding_state, truncation="only_second"
)
self.assertEqual(len(output["input_ids"]), model_max_length)
self.assertEqual(len(output["bbox"]), model_max_length)
output = tokenizer(
[question_1], [seq_2], boxes=[boxes_2], padding=padding_state, truncation="only_second"
)
self.assertEqual(len(output["input_ids"][0]), model_max_length)
self.assertEqual(len(output["bbox"][0]), model_max_length)
# Simple with no truncation
# Reset warnings
tokenizer.deprecation_warnings = {}
with self.assertLogs("transformers", level="WARNING") as cm:
output = tokenizer(
question_1, seq_2, boxes=boxes_2, padding=padding_state, truncation=False
)
self.assertNotEqual(len(output["input_ids"]), model_max_length)
self.assertNotEqual(len(output["bbox"]), model_max_length)
self.assertEqual(len(cm.records), 1)
self.assertTrue(
cm.records[0].message.startswith(
"Token indices sequence length is longer than the specified maximum sequence length for this model"
)
)
tokenizer.deprecation_warnings = {}
with self.assertLogs("transformers", level="WARNING") as cm:
output = tokenizer(
[question_1], [seq_2], boxes=[boxes_2], padding=padding_state, truncation=False
)
self.assertNotEqual(len(output["input_ids"][0]), model_max_length)
self.assertNotEqual(len(output["bbox"][0]), model_max_length)
self.assertEqual(len(cm.records), 1)
self.assertTrue(
cm.records[0].message.startswith(
"Token indices sequence length is longer than the specified maximum sequence length for this model"
)
)
# Check the order of Sequence of input ids, overflowing tokens and bbox sequence with truncation
truncated_first_sequence = (
tokenizer(seq_0, boxes=boxes_0, add_special_tokens=False)["input_ids"][:-2]
+ tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)["input_ids"]
)
truncated_second_sequence = (
tokenizer(seq_0, boxes=boxes_0, add_special_tokens=False)["input_ids"]
+ tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)["input_ids"][:-2]
)
truncated_longest_sequence = (
truncated_first_sequence if len(seq0_tokens) > len(seq1_tokens) else truncated_second_sequence
)
overflow_first_sequence = (
tokenizer(seq_0, boxes=boxes_0, add_special_tokens=False)["input_ids"][-(2 + stride) :]
+ tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)["input_ids"]
)
overflow_second_sequence = (
tokenizer(seq_0, boxes=boxes_0, add_special_tokens=False)["input_ids"]
+ tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)["input_ids"][-(2 + stride) :]
)
overflow_longest_sequence = (
overflow_first_sequence if len(seq0_tokens) > len(seq1_tokens) else overflow_second_sequence
)
bbox_first = [[0, 0, 0, 0]] * (len(seq_0) - 2)
bbox_first_sequence = bbox_first + tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)["bbox"]
overflowing_token_bbox_first_sequence_slow = [[0, 0, 0, 0]] * (2 + stride)
overflowing_token_bbox_first_sequence_fast = [[0, 0, 0, 0]] * (2 + stride) + tokenizer(
seq_1, boxes=boxes_1, add_special_tokens=False
)["bbox"]
bbox_second = [[0, 0, 0, 0]] * len(seq_0)
bbox_second_sequence = (
bbox_second + tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)["bbox"][:-2]
)
overflowing_token_bbox_second_sequence_slow = tokenizer(
seq_1, boxes=boxes_1, add_special_tokens=False
)["bbox"][-(2 + stride) :]
overflowing_token_bbox_second_sequence_fast = [[0, 0, 0, 0]] * len(seq_0) + tokenizer(
seq_1, boxes=boxes_1, add_special_tokens=False
)["bbox"][-(2 + stride) :]
bbox_longest_sequence = (
bbox_first_sequence if len(seq0_tokens) > len(seq1_tokens) else bbox_second_sequence
)
overflowing_token_bbox_longest_sequence_fast = (
overflowing_token_bbox_first_sequence_fast
if len(seq0_tokens) > len(seq1_tokens)
else overflowing_token_bbox_second_sequence_fast
)
# Overflowing tokens are handled quite differently in slow and fast tokenizers
if isinstance(tokenizer, LayoutLMv2TokenizerFast):
information = tokenizer(
question_0,
seq_1,
boxes=boxes_1,
max_length=len(sequence["input_ids"]) - 2,
add_special_tokens=False,
stride=stride,
truncation="longest_first",
return_overflowing_tokens=True,
# add_prefix_space=False,
)
truncated_sequence = information["input_ids"][0]
overflowing_tokens = information["input_ids"][1]
bbox = information["bbox"][0]
overflowing_bbox = information["bbox"][1]
self.assertEqual(len(information["input_ids"]), 2)
self.assertEqual(len(truncated_sequence), len(sequence["input_ids"]) - 2)
self.assertEqual(truncated_sequence, truncated_longest_sequence)
self.assertEqual(len(overflowing_tokens), 2 + stride + len(smallest))
self.assertEqual(overflowing_tokens, overflow_longest_sequence)
self.assertEqual(bbox, bbox_longest_sequence)
self.assertEqual(len(overflowing_bbox), 2 + stride + len(smallest))
self.assertEqual(overflowing_bbox, overflowing_token_bbox_longest_sequence_fast)
else:
# No overflowing tokens when using 'longest' in python tokenizers
with self.assertRaises(ValueError) as context:
information = tokenizer(
question_0,
seq_1,
boxes=boxes_1,
max_length=len(sequence["input_ids"]) - 2,
add_special_tokens=False,
stride=stride,
truncation="longest_first",
return_overflowing_tokens=True,
# add_prefix_space=False,
)
self.assertTrue(
context.exception.args[0].startswith(
"Not possible to return overflowing tokens for pair of sequences with the "
"`longest_first`. Please select another truncation strategy than `longest_first`, "
"for instance `only_second` or `only_first`."
)
)
# Overflowing tokens are handled quite differently in slow and fast tokenizers
if isinstance(tokenizer, LayoutLMv2TokenizerFast):
information = tokenizer(
question_0,
seq_1,
boxes=boxes_1,
max_length=len(sequence["input_ids"]) - 2,
add_special_tokens=False,
stride=stride,
truncation=True,
return_overflowing_tokens=True,
# add_prefix_space=False,
)
truncated_sequence = information["input_ids"][0]
overflowing_tokens = information["input_ids"][1]
bbox = information["bbox"][0]
overflowing_bbox = information["bbox"][1]
self.assertEqual(len(information["input_ids"]), 2)
self.assertEqual(len(truncated_sequence), len(sequence["input_ids"]) - 2)
self.assertEqual(truncated_sequence, truncated_longest_sequence)
self.assertEqual(len(overflowing_tokens), 2 + stride + len(smallest))
self.assertEqual(overflowing_tokens, overflow_longest_sequence)
self.assertEqual(bbox, bbox_longest_sequence)
self.assertEqual(overflowing_bbox, overflowing_token_bbox_longest_sequence_fast)
else:
# No overflowing tokens when using 'longest' in python tokenizers
with self.assertRaises(ValueError) as context:
information = tokenizer(
question_0,
seq_1,
boxes=boxes_1,
max_length=len(sequence["input_ids"]) - 2,
add_special_tokens=False,
stride=stride,
truncation=True,
return_overflowing_tokens=True,
# add_prefix_space=False,
)
self.assertTrue(
context.exception.args[0].startswith(
"Not possible to return overflowing tokens for pair of sequences with the "
"`longest_first`. Please select another truncation strategy than `longest_first`, "
"for instance `only_second` or `only_first`."
)
)
information_first_truncated = tokenizer(
question_0,
seq_1,
boxes=boxes_1,
max_length=len(sequence["input_ids"]) - 2,
add_special_tokens=False,
stride=stride,
truncation="only_first",
return_overflowing_tokens=True,
# add_prefix_space=False,
)
# Overflowing tokens are handled quite differently in slow and fast tokenizers
if isinstance(tokenizer, LayoutLMv2TokenizerFast):
truncated_sequence = information_first_truncated["input_ids"][0]
overflowing_tokens = information_first_truncated["input_ids"][1]
bbox = information_first_truncated["bbox"][0]
overflowing_bbox = information_first_truncated["bbox"][1]
self.assertEqual(len(information_first_truncated["input_ids"]), 2)
self.assertEqual(len(truncated_sequence), len(sequence["input_ids"]) - 2)
self.assertEqual(truncated_sequence, truncated_first_sequence)
self.assertEqual(len(overflowing_tokens), 2 + stride + len(seq1_tokens["input_ids"]))
self.assertEqual(overflowing_tokens, overflow_first_sequence)
self.assertEqual(bbox, bbox_first_sequence)
self.assertEqual(overflowing_bbox, overflowing_token_bbox_first_sequence_fast)
else:
truncated_sequence = information_first_truncated["input_ids"]
overflowing_tokens = information_first_truncated["overflowing_tokens"]
overflowing_bbox = information_first_truncated["overflowing_token_boxes"]
bbox = information_first_truncated["bbox"]
self.assertEqual(len(truncated_sequence), len(sequence["input_ids"]) - 2)
self.assertEqual(truncated_sequence, truncated_first_sequence)
self.assertEqual(len(overflowing_tokens), 2 + stride)
self.assertEqual(overflowing_tokens, seq0_tokens["input_ids"][-(2 + stride) :])
self.assertEqual(bbox, bbox_first_sequence)
self.assertEqual(overflowing_bbox, overflowing_token_bbox_first_sequence_slow)
information_second_truncated = tokenizer(
question_0,
seq_1,
boxes=boxes_1,
max_length=len(sequence["input_ids"]) - 2,
add_special_tokens=False,
stride=stride,
truncation="only_second",
return_overflowing_tokens=True,
# add_prefix_space=False,
)
# Overflowing tokens are handled quite differently in slow and fast tokenizers
if isinstance(tokenizer, LayoutLMv2TokenizerFast):
truncated_sequence = information_second_truncated["input_ids"][0]
overflowing_tokens = information_second_truncated["input_ids"][1]
bbox = information_second_truncated["bbox"][0]
overflowing_bbox = information_second_truncated["bbox"][1]
self.assertEqual(len(information_second_truncated["input_ids"]), 2)
self.assertEqual(len(truncated_sequence), len(sequence["input_ids"]) - 2)
self.assertEqual(truncated_sequence, truncated_second_sequence)
self.assertEqual(len(overflowing_tokens), 2 + stride + len(seq0_tokens["input_ids"]))
self.assertEqual(overflowing_tokens, overflow_second_sequence)
self.assertEqual(bbox, bbox_second_sequence)
self.assertEqual(overflowing_bbox, overflowing_token_bbox_second_sequence_fast)
else:
truncated_sequence = information_second_truncated["input_ids"]
overflowing_tokens = information_second_truncated["overflowing_tokens"]
bbox = information_second_truncated["bbox"]
overflowing_bbox = information_second_truncated["overflowing_token_boxes"]
self.assertEqual(len(truncated_sequence), len(sequence["input_ids"]) - 2)
self.assertEqual(truncated_sequence, truncated_second_sequence)
self.assertEqual(len(overflowing_tokens), 2 + stride)
self.assertEqual(overflowing_tokens, seq1_tokens["input_ids"][-(2 + stride) :])
self.assertEqual(bbox, bbox_second_sequence)
self.assertEqual(overflowing_bbox, overflowing_token_bbox_second_sequence_slow)
# @unittest.skip("LayoutLMv2 tokenizer requires boxes besides sequences.")
def test_maximum_encoding_length_single_input(self):
-pass
tokenizers = self.get_tokenizers(do_lower_case=False, model_max_length=100)
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
seq_0, boxes_0, ids = self.get_clean_sequence(tokenizer, max_length=20)
sequence = tokenizer(seq_0, boxes=boxes_0, add_special_tokens=False)
total_length = len(sequence["input_ids"])
self.assertGreater(total_length, 4, "Issue with the testing sequence, please update it it's too short")
# Test with max model input length
model_max_length = tokenizer.model_max_length
self.assertEqual(model_max_length, 100)
seq_1 = seq_0 * model_max_length
boxes_1 = boxes_0 * model_max_length
sequence1 = tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)
total_length1 = len(sequence1["input_ids"])
self.assertGreater(
total_length1, model_max_length, "Issue with the testing sequence, please update it it's too short"
)
# Simple
padding_strategies = (
[False, True, "longest"] if tokenizer.pad_token and tokenizer.pad_token_id >= 0 else [False]
)
for padding_state in padding_strategies:
with self.subTest(f"Padding: {padding_state}"):
for truncation_state in [True, "longest_first", "only_first"]:
with self.subTest(f"Truncation: {truncation_state}"):
output = tokenizer(
seq_1,
boxes=boxes_1,
padding=padding_state,
truncation=truncation_state,
)
self.assertEqual(len(output["input_ids"]), model_max_length)
self.assertEqual(len(output["bbox"]), model_max_length)
output = tokenizer(
[seq_1],
boxes=[boxes_1],
padding=padding_state,
truncation=truncation_state,
)
self.assertEqual(len(output["input_ids"][0]), model_max_length)
self.assertEqual(len(output["bbox"][0]), model_max_length)
# Simple with no truncation
# Reset warnings
tokenizer.deprecation_warnings = {}
with self.assertLogs("transformers", level="WARNING") as cm:
output = tokenizer(seq_1, boxes=boxes_1, padding=padding_state, truncation=False)
self.assertNotEqual(len(output["input_ids"]), model_max_length)
self.assertNotEqual(len(output["bbox"]), model_max_length)
self.assertEqual(len(cm.records), 1)
self.assertTrue(
cm.records[0].message.startswith(
"Token indices sequence length is longer than the specified maximum sequence length for this model"
)
)
tokenizer.deprecation_warnings = {}
with self.assertLogs("transformers", level="WARNING") as cm:
output = tokenizer([seq_1], boxes=[boxes_1], padding=padding_state, truncation=False)
self.assertNotEqual(len(output["input_ids"][0]), model_max_length)
self.assertNotEqual(len(output["bbox"][0]), model_max_length)
self.assertEqual(len(cm.records), 1)
self.assertTrue(
cm.records[0].message.startswith(
"Token indices sequence length is longer than the specified maximum sequence length for this model"
)
)
# Check the order of Sequence of input ids, overflowing tokens and bbox sequence with truncation
stride = 2
information = tokenizer(
seq_0,
boxes=boxes_0,
max_length=total_length - 2,
add_special_tokens=False,
stride=stride,
truncation=True,
return_overflowing_tokens=True,
# add_prefix_space=False,
)
# Overflowing tokens are handled quite differently in slow and fast tokenizers
if isinstance(tokenizer, LayoutLMv2TokenizerFast):
truncated_sequence = information["input_ids"][0]
overflowing_tokens = information["input_ids"][1]
bbox = information["bbox"][0]
overflowing_bbox = information["bbox"][1]
self.assertEqual(len(information["input_ids"]), 2)
self.assertEqual(len(truncated_sequence), total_length - 2)
self.assertEqual(truncated_sequence, sequence["input_ids"][:-2])
self.assertEqual(len(overflowing_tokens), 2 + stride)
self.assertEqual(overflowing_tokens, sequence["input_ids"][-(2 + stride) :])
self.assertEqual(bbox, sequence["bbox"][:-2])
self.assertEqual(overflowing_bbox, sequence["bbox"][-(2 + stride) :])
else:
truncated_sequence = information["input_ids"]
overflowing_tokens = information["overflowing_tokens"]
bbox = information["bbox"]
overflowing_bbox = information["overflowing_token_boxes"]
self.assertEqual(len(truncated_sequence), total_length - 2)
self.assertEqual(truncated_sequence, sequence["input_ids"][:-2])
self.assertEqual(len(overflowing_tokens), 2 + stride)
self.assertEqual(overflowing_tokens, sequence["input_ids"][-(2 + stride) :])
self.assertEqual(bbox, sequence["bbox"][:-2])
self.assertEqual(overflowing_bbox, sequence["bbox"][-(2 + stride) :])
@unittest.skip("LayoutLMv2 tokenizer requires boxes besides sequences.")
def test_pretokenized_inputs(self):
......
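As the two branches in these tests show, fast and slow tokenizers package overflow differently. Roughly (sketch; `fast_tokenizer`/`slow_tokenizer` are placeholders, `words`/`boxes` are word and box lists as in the earlier sketches):

```python
common = dict(
    boxes=boxes,
    add_special_tokens=False,
    max_length=4,
    stride=1,
    truncation=True,
    return_overflowing_tokens=True,
)

# Fast: the truncated window and the overflow window come back stacked as two
# sequences in "input_ids" (and likewise in "bbox").
fast_out = fast_tokenizer(words, **common)
kept, overflow = fast_out["input_ids"][0], fast_out["input_ids"][1]
kept_bbox, overflow_bbox = fast_out["bbox"][0], fast_out["bbox"][1]

# Slow: the overflow lives under dedicated keys.
slow_out = slow_tokenizer(words, **common)
kept, overflow = slow_out["input_ids"], slow_out["overflowing_tokens"]
overflow_bbox = slow_out["overflowing_token_boxes"]
```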