Unverified Commit c0eb218a authored by Hamel Husain, committed by GitHub

Update `PreTrainedTokenizerBase` to check/handle batch length for `text_pair` parameter (#11486)



* Update tokenization_utils_base.py

* add assertion

* check batch len

* Update src/transformers/tokenization_utils_base.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* add error message
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 2d27900b
@@ -2279,6 +2279,14 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
             )
         if is_batched:
+            if isinstance(text_pair, str):
+                raise TypeError(
+                    "when tokenizing batches of text, `text_pair` must be a list or tuple with the same length as `text`."
+                )
+            if text_pair is not None and len(text) != len(text_pair):
+                raise ValueError(
+                    f"batch length of `text`: {len(text)} does not match batch length of `text_pair`: {len(text_pair)}."
+                )
             batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
             return self.batch_encode_plus(
                 batch_text_or_text_pairs=batch_text_or_text_pairs,
...
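
For reference, a minimal sketch of how the new checks surface to callers. The checkpoint `bert-base-uncased` is chosen only for illustration; any tokenizer would behave the same way, since the checks live in `PreTrainedTokenizerBase.__call__`.

```python
from transformers import AutoTokenizer

# Any tokenizer checkpoint works; `bert-base-uncased` is just an example.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

text = ["How old are you?", "What is your name?"]

# A single string for `text_pair` alongside a batched `text` now raises TypeError
# instead of silently zipping the string's characters against the batch.
try:
    tokenizer(text, text_pair="I'm six years old.")
except TypeError as err:
    print(err)

# A `text_pair` batch whose length differs from `text` now raises ValueError.
try:
    tokenizer(text, text_pair=["I'm six years old."])
except ValueError as err:
    print(err)
```

Before this change, a mismatched `text_pair` was passed straight into `list(zip(text, text_pair))`, so extra or missing pairs were silently dropped or mispaired rather than reported.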