Unverified Commit 15478c12 authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Batch encore plus and overflowing tokens fails when non existing overflowing...

Batch encore plus and overflowing tokens fails when non existing overflowing tokens for a sequence (#6677)

* Patch and test

* Fix tests
parent 9fd11bf1
......@@ -2440,6 +2440,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
# Truncation: Handle max sequence length
overflowing_tokens = []
if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
ids, pair_ids, overflowing_tokens = self.truncate_sequences(
ids,
......@@ -2448,9 +2449,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
truncation_strategy=truncation_strategy,
stride=stride,
)
if return_overflowing_tokens:
encoded_inputs["overflowing_tokens"] = overflowing_tokens
encoded_inputs["num_truncated_tokens"] = total_len - max_length
if return_overflowing_tokens:
encoded_inputs["overflowing_tokens"] = overflowing_tokens
encoded_inputs["num_truncated_tokens"] = total_len - max_length
# Add special tokens
if add_special_tokens:
......
......@@ -1352,6 +1352,18 @@ class TokenizerTesterMixin:
self.assertEqual(input_dict, prepared_input_dict)
def test_batch_encode_plus_overflowing_tokens(self):
tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers:
string_sequences = ["Testing the prepare_for_model method.", "Test"]
if tokenizer.pad_token is None:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
tokenizer.batch_encode_plus(
string_sequences, return_overflowing_tokens=True, truncation=True, padding=True, max_length=3
)
@require_torch
@require_tf
def test_batch_encode_plus_tensors(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment