Unverified Commit 6326aa4b authored by Apoorv Garg's avatar Apoorv Garg Committed by GitHub
Browse files

Correct order of overflowing tokens for LayoutLmV2 tokenizer (#13495)



* correct order of overflowing tokens for LayoutLmV2 tokenizer

* test to check order of overflowing_tokens for a seq of input_ids

* fix up quality

* added suggested changes

* check that tests the bbox sequence

* pair_input test added

* pass quality test

* check bbox sequence added

* unittest method

* comments added

* add overflowing bbox test

* improved "seq_1"
Co-authored-by: default avatarSaulLu <55560583+SaulLu@users.noreply.github.com>

* improve code quality
Co-authored-by: default avatarSaulLu <lucilesaul.com@gmail.com>
Co-authored-by: default avatarSaulLu <55560583+SaulLu@users.noreply.github.com>
parent 95b3ec3b
......@@ -650,7 +650,7 @@ class LayoutLMv2Tokenizer(PreTrainedTokenizer):
"""
Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It
adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
manages a moving window (with user defined stride) for overflowing tokens
manages a moving window (with user defined stride) for overflowing tokens.
Args:
batch_ids_pairs: list of tokenized input ids or input ids pairs
......@@ -893,7 +893,9 @@ class LayoutLMv2Tokenizer(PreTrainedTokenizer):
"""
Prepares a sequence or a pair of sequences so that it can be used by the model. It adds special tokens,
truncates sequences if overflowing while taking into account the special tokens and manages a moving window
(with user defined stride) for overflowing tokens.
(with user defined stride) for overflowing tokens. Please Note, for `text_pair` different than `None` and
`truncation_strategy = longest_first` or `True`, it is not possible to return overflowing tokens. Such a
combination of arguments will raise an error.
Word-level :obj:`boxes` are turned into token-level :obj:`bbox`. If provided, word-level :obj:`word_labels` are
turned into token-level :obj:`labels`. The word label is used for the first token of the word, while remaining
......@@ -963,6 +965,17 @@ class LayoutLMv2Tokenizer(PreTrainedTokenizer):
ids = self.convert_tokens_to_ids(tokens)
pair_ids = self.convert_tokens_to_ids(pair_tokens) if pair_tokens else None
if (
return_overflowing_tokens
and truncation_strategy == TruncationStrategy.LONGEST_FIRST
and pair_ids is not None
):
raise ValueError(
"Not possible to return overflowing tokens for pair of sequences with the "
"`longest_first`. Please select another truncation strategy than `longest_first`, "
"for instance `only_second` or `only_first`."
)
# Compute the total size of the returned encodings
pair = bool(pair_ids is not None)
len_ids = len(ids)
......@@ -1114,7 +1127,8 @@ class LayoutLMv2Tokenizer(PreTrainedTokenizer):
Returns:
:obj:`Tuple[List[int], List[int], List[int]]`: The truncated ``ids``, the truncated ``pair_ids`` and the
list of overflowing tokens.
list of overflowing tokens. Note: The `longest_first` strategy returns empty list of overflowing tokens if
a pair of sequences (or a batch of pairs) is provided.
"""
if num_tokens_to_remove <= 0:
return ids, token_boxes, pair_ids, pair_token_boxes, labels, [], [], []
......@@ -1125,29 +1139,9 @@ class LayoutLMv2Tokenizer(PreTrainedTokenizer):
overflowing_tokens = []
overflowing_token_boxes = []
overflowing_labels = []
if truncation_strategy == TruncationStrategy.LONGEST_FIRST:
for _ in range(num_tokens_to_remove):
if pair_ids is None or len(ids) > len(pair_ids):
if not overflowing_tokens:
window_len = min(len(ids), stride + 1)
else:
window_len = 1
overflowing_tokens.extend(ids[-window_len:])
overflowing_token_boxes.extend(token_boxes[-window_len:])
overflowing_labels.extend(labels[-window_len:])
ids = ids[:-1]
token_boxes = token_boxes[:-1]
labels = labels[:-1]
else:
if not overflowing_tokens:
window_len = min(len(pair_ids), stride + 1)
else:
window_len = 1
overflowing_tokens.extend(pair_ids[-window_len:])
overflowing_token_boxes.extend(pair_token_boxes[-window_len:])
pair_ids = pair_ids[:-1]
pair_token_boxes = pair_token_boxes[:-1]
elif truncation_strategy == TruncationStrategy.ONLY_FIRST:
if truncation_strategy == TruncationStrategy.ONLY_FIRST or (
truncation_strategy == TruncationStrategy.LONGEST_FIRST and pair_ids is None
):
if len(ids) > num_tokens_to_remove:
window_len = min(len(ids), stride + num_tokens_to_remove)
overflowing_tokens = ids[-window_len:]
......@@ -1157,12 +1151,31 @@ class LayoutLMv2Tokenizer(PreTrainedTokenizer):
token_boxes = token_boxes[:-num_tokens_to_remove]
labels = labels[:-num_tokens_to_remove]
else:
logger.error(
error_msg = (
f"We need to remove {num_tokens_to_remove} to truncate the input "
f"but the first sequence has a length {len(ids)}. "
f"Please select another truncation strategy than {truncation_strategy}, "
f"for instance 'longest_first' or 'only_second'."
)
if truncation_strategy == TruncationStrategy.ONLY_FIRST:
error_msg = (
error_msg + "Please select another truncation strategy than "
f"{truncation_strategy}, for instance 'longest_first' or 'only_second'."
)
logger.error(error_msg)
elif truncation_strategy == TruncationStrategy.LONGEST_FIRST:
logger.warning(
f"Be aware, overflowing tokens are not returned for the setting you have chosen,"
f" i.e. sequence pairs with the '{TruncationStrategy.LONGEST_FIRST.value}' "
f"truncation strategy. So the returned list will always be empty even if some "
f"tokens have been removed."
)
for _ in range(num_tokens_to_remove):
if pair_ids is None or len(ids) > len(pair_ids):
ids = ids[:-1]
token_boxes = token_boxes[:-1]
labels = labels[:-1]
else:
pair_ids = pair_ids[:-1]
pair_token_boxes = pair_token_boxes[:-1]
elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None:
if len(pair_ids) > num_tokens_to_remove:
window_len = min(len(pair_ids), stride + num_tokens_to_remove)
......
......@@ -3015,7 +3015,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
Returns:
:obj:`Tuple[List[int], List[int], List[int]]`: The truncated ``ids``, the truncated ``pair_ids`` and the
list of overflowing tokens. Note: The `longest_first` strategy returns empty list of overflowing_tokens if
list of overflowing tokens. Note: The `longest_first` strategy returns empty list of overflowing tokens if
a pair of sequences (or a batch of pairs) is provided.
"""
if num_tokens_to_remove <= 0:
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment