Unverified Commit 15478c12 authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Batch encode plus and overflowing tokens fails when non-existing overflowing...

Batch encode plus and overflowing tokens fails when non-existing overflowing tokens for a sequence (#6677)

* Patch and test

* Fix tests
parent 9fd11bf1
...@@ -2440,6 +2440,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): ...@@ -2440,6 +2440,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0) total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
# Truncation: Handle max sequence length # Truncation: Handle max sequence length
overflowing_tokens = []
if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length: if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
ids, pair_ids, overflowing_tokens = self.truncate_sequences( ids, pair_ids, overflowing_tokens = self.truncate_sequences(
ids, ids,
...@@ -2448,9 +2449,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): ...@@ -2448,9 +2449,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
truncation_strategy=truncation_strategy, truncation_strategy=truncation_strategy,
stride=stride, stride=stride,
) )
if return_overflowing_tokens:
encoded_inputs["overflowing_tokens"] = overflowing_tokens if return_overflowing_tokens:
encoded_inputs["num_truncated_tokens"] = total_len - max_length encoded_inputs["overflowing_tokens"] = overflowing_tokens
encoded_inputs["num_truncated_tokens"] = total_len - max_length
# Add special tokens # Add special tokens
if add_special_tokens: if add_special_tokens:
......
...@@ -1352,6 +1352,18 @@ class TokenizerTesterMixin: ...@@ -1352,6 +1352,18 @@ class TokenizerTesterMixin:
self.assertEqual(input_dict, prepared_input_dict) self.assertEqual(input_dict, prepared_input_dict)
def test_batch_encode_plus_overflowing_tokens(self):
    """Regression test: batch_encode_plus with return_overflowing_tokens=True
    must not fail when some sequences in the batch are too short to produce
    any overflowing tokens (here "Test" vs. a longer sentence at max_length=3).
    """
    # Same fixed batch for every tokenizer under test: one sequence that
    # overflows at max_length=3 and one that does not.
    sample_batch = ["Testing the prepare_for_model method.", "Test"]
    for current_tokenizer in self.get_tokenizers(do_lower_case=False):
        # Padding requires a pad token; install one if the tokenizer lacks it.
        if current_tokenizer.pad_token is None:
            current_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        # The assertion is simply that this call does not raise.
        current_tokenizer.batch_encode_plus(
            sample_batch, return_overflowing_tokens=True, truncation=True, padding=True, max_length=3
        )
@require_torch @require_torch
@require_tf @require_tf
def test_batch_encode_plus_tensors(self): def test_batch_encode_plus_tensors(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment