Unverified Commit a8c3f9aa authored by Nicolas Patry, committed by GitHub

Warning about too long input for fast tokenizers too (#8799)

* Warning about too long input for fast tokenizers too

If truncation is not set in the tokenizer but the tokenized input is longer
than the model's `model_max_length`, we used to trigger a warning that the
input would probably fail (which it most likely will).

This PR re-enables the warning for fast tokenizers too and uses common
code for the trigger to keep it consistent across the slow and fast implementations.

* Checking for pairs of inputs too.

* Making the function private and adding its docstring.

* Removing stray `??` formatting from an odd place.

* Missed uppercase.
parent f6b44e61
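
For context, here is a minimal sketch of the behavior this change restores for fast tokenizers; the checkpoint name and the 512-token limit are illustrative assumptions, not part of this diff:

```python
from transformers import AutoTokenizer

# Hypothetical checkpoint; any tokenizer with a finite model_max_length behaves the same.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)

# Assumed to encode to far more tokens than the model's limit (512 for this checkpoint).
long_text = "hello " * 1000

# No truncation and no explicit max_length: with this patch the fast tokenizer now logs
# "Token indices sequence length is longer than the specified maximum sequence length ..."
# just like the slow (Python) tokenizer already did.
encoding = tokenizer(long_text, truncation=False)
print(len(encoding["input_ids"]), tokenizer.model_max_length)
```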
@@ -2866,14 +2866,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
             encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
 
         # Check lengths
-        if max_length is None and len(encoded_inputs["input_ids"]) > self.model_max_length and verbose:
-            if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False):
-                logger.warning(
-                    "Token indices sequence length is longer than the specified maximum sequence length "
-                    "for this model ({} > {}). Running this sequence through the model will result in "
-                    "indexing errors".format(len(encoded_inputs["input_ids"]), self.model_max_length)
-                )
-            self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True
+        self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose)
 
         # Padding
         if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
@@ -3204,3 +3197,23 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
             .replace(" 're", "'re")
         )
         return out_string
+
+    def _eventual_warn_about_too_long_sequence(self, ids: List[int], max_length: Optional[int], verbose: bool):
+        """
+        Depending on the input and internal state we might trigger a warning about a sequence that is too long for its
+        corresponding model.
+
+        Args:
+            ids (:obj:`List[int]`): The ids produced by the tokenization
+            max_length (:obj:`int`, `optional`): The max_length desired (does not trigger a warning if it is set)
+            verbose (:obj:`bool`): Whether or not to print more information and warnings.
+        """
+        if max_length is None and len(ids) > self.model_max_length and verbose:
+            if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False):
+                logger.warning(
+                    "Token indices sequence length is longer than the specified maximum sequence length "
+                    "for this model ({} > {}). Running this sequence through the model will result in "
+                    "indexing errors".format(len(ids), self.model_max_length)
+                )
+            self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True
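
Note that the helper warns at most once per tokenizer instance, since it records a key in `self.deprecation_warnings`; that is why the updated tests below reset the dict before each `assertLogs` block. A rough sketch of that behavior (checkpoint and input are again illustrative assumptions):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")  # hypothetical checkpoint
too_long = "word " * 2000  # assumed to encode past model_max_length

tok.deprecation_warnings = {}       # reset the per-instance cache, as the tests do
tok(too_long, truncation=False)     # first call: the warning is logged
tok(too_long, truncation=False)     # second call: silent, the key is already recorded
```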
@@ -418,6 +418,8 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
                 overflow_to_sample_mapping += [i] * len(toks["input_ids"])
             sanitized_tokens["overflow_to_sample_mapping"] = overflow_to_sample_mapping
 
+        for input_ids in sanitized_tokens["input_ids"]:
+            self._eventual_warn_about_too_long_sequence(input_ids, max_length, verbose)
         return BatchEncoding(sanitized_tokens, sanitized_encodings, tensor_type=return_tensors)
 
     def _encode_plus(
@@ -474,6 +476,8 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
                 batched_output.encodings,
             )
 
+        self._eventual_warn_about_too_long_sequence(batched_output["input_ids"], max_length, verbose)
+
         return batched_output
 
     def convert_tokens_to_string(self, tokens: List[str]) -> str:
...
@@ -666,11 +666,28 @@ class TokenizerTesterMixin:
                     self.assertEqual(len(output["input_ids"][0]), model_max_length)
 
                     # Simple with no truncation
-                    output = tokenizer(seq_1, padding=padding_state, truncation=False)
-                    self.assertNotEqual(len(output["input_ids"]), model_max_length)
+                    # Reset warnings
+                    tokenizer.deprecation_warnings = {}
+                    with self.assertLogs("transformers", level="WARNING") as cm:
+                        output = tokenizer(seq_1, padding=padding_state, truncation=False)
+                        self.assertNotEqual(len(output["input_ids"]), model_max_length)
+                    self.assertEqual(len(cm.records), 1)
+                    self.assertTrue(
+                        cm.records[0].message.startswith(
+                            "Token indices sequence length is longer than the specified maximum sequence length for this model"
+                        )
+                    )
 
-                    output = tokenizer([seq_1], padding=padding_state, truncation=False)
-                    self.assertNotEqual(len(output["input_ids"][0]), model_max_length)
+                    tokenizer.deprecation_warnings = {}
+                    with self.assertLogs("transformers", level="WARNING") as cm:
+                        output = tokenizer([seq_1], padding=padding_state, truncation=False)
+                        self.assertNotEqual(len(output["input_ids"][0]), model_max_length)
+                    self.assertEqual(len(cm.records), 1)
+                    self.assertTrue(
+                        cm.records[0].message.startswith(
+                            "Token indices sequence length is longer than the specified maximum sequence length for this model"
+                        )
+                    )
 
                     # Overflowing tokens
                     stride = 2
@@ -770,11 +787,28 @@ class TokenizerTesterMixin:
                     self.assertEqual(len(output["input_ids"][0]), model_max_length)
 
                     # Simple with no truncation
-                    output = tokenizer(seq_1, seq_2, padding=padding_state, truncation=False)
-                    self.assertNotEqual(len(output["input_ids"]), model_max_length)
+                    # Reset warnings
+                    tokenizer.deprecation_warnings = {}
+                    with self.assertLogs("transformers", level="WARNING") as cm:
+                        output = tokenizer(seq_1, seq_2, padding=padding_state, truncation=False)
+                        self.assertNotEqual(len(output["input_ids"]), model_max_length)
+                    self.assertEqual(len(cm.records), 1)
+                    self.assertTrue(
+                        cm.records[0].message.startswith(
+                            "Token indices sequence length is longer than the specified maximum sequence length for this model"
+                        )
+                    )
 
-                    output = tokenizer([seq_1], [seq_2], padding=padding_state, truncation=False)
-                    self.assertNotEqual(len(output["input_ids"][0]), model_max_length)
+                    tokenizer.deprecation_warnings = {}
+                    with self.assertLogs("transformers", level="WARNING") as cm:
+                        output = tokenizer([seq_1], [seq_2], padding=padding_state, truncation=False)
+                        self.assertNotEqual(len(output["input_ids"][0]), model_max_length)
+                    self.assertEqual(len(cm.records), 1)
+                    self.assertTrue(
+                        cm.records[0].message.startswith(
+                            "Token indices sequence length is longer than the specified maximum sequence length for this model"
+                        )
+                    )
 
                     truncated_first_sequence = tokenizer.encode(seq_0, add_special_tokens=False)[:-2] + tokenizer.encode(
                         seq_1, add_special_tokens=False
...