Unverified Commit 601d4d69 authored by Thomas Wolf, committed by GitHub

[tokenizers] Updates data processors, docstring, examples and model cards to the new API (#5308)

* remove references to old API in docstring - update data processors

* style

* fix tests - better type checking error messages

* better type checking

* include awesome fix by @LysandreJik for #5310

* updated doc and examples
parent fd405e9a
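This commit replaces explicit `encode_plus`/`batch_encode_plus` calls and the `pad_to_max_length` flag with the new tokenizer `__call__` API and its explicit `padding`/`truncation` arguments. A minimal before/after sketch of that migration, assuming an illustrative checkpoint (`bert-base-uncased`) that is not part of this diff:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint

# Old API (what the docstrings and tests referenced before this commit);
# pad_to_max_length is deprecated in favour of the padding argument.
old_style = tokenizer.batch_encode_plus(
    ["first sentence", "second sentence"], max_length=32, pad_to_max_length=True, return_tensors="pt",
)

# New API: call the tokenizer directly with explicit padding / truncation arguments.
new_style = tokenizer(
    ["first sentence", "second sentence"], max_length=32, padding="max_length", truncation=True, return_tensors="pt",
)

assert old_style["input_ids"].shape == new_style["input_ids"].shape  # both (2, 32)
```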
@@ -125,7 +125,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
return self.sp_model.IdToPiece(index - self.fairseq_offset)
def set_lang(self, lang: str) -> None:
"""Set the current language code in order to call batch_encode_plus properly."""
"""Set the current language code in order to call tokenizer properly."""
self.cur_lang_code = self.lang_code_to_id[lang]
def prepare_translation_batch(
......
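A short sketch of how the reworded `set_lang` docstring is meant to be read: point the tokenizer at a language, then call it directly rather than through `batch_encode_plus`. The checkpoint name and language code below are assumptions for illustration, not part of this hunk.

```python
from transformers import MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")  # assumed checkpoint

# Select the language whose code should be added as a special token,
# then use the plain tokenizer call (new API).
tokenizer.set_lang("ro_RO")  # assumed language code
batch = tokenizer(["Bine ați venit!"], padding=True, return_tensors="pt")
```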
@@ -263,7 +263,7 @@ class BertTokenizer(PreTrainedTokenizer):
) -> List[int]:
"""
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
special tokens using the tokenizer ``prepare_for_model`` method.
Args:
token_ids_0 (:obj:`List[int]`):
......
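For context on the docstring being edited, a small sketch of what `get_special_tokens_mask` returns for a single sequence without special tokens, assuming `bert-base-uncased` (the checkpoint is illustrative):

```python
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint

# Ids without special tokens, as prepare_for_model receives them.
ids = tokenizer.encode("hello world", add_special_tokens=False)

# 1 marks positions where a special token ([CLS]/[SEP]) will be inserted, 0 marks sequence tokens.
mask = tokenizer.get_special_tokens_mask(ids)
print(mask)  # [1, 0, 0, 1] for the two-token input above
```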
@@ -171,7 +171,7 @@ class CamembertTokenizer(PreTrainedTokenizer):
) -> List[int]:
"""
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
special tokens using the tokenizer ``prepare_for_model`` method.
Args:
token_ids_0 (:obj:`List[int]`):
......
@@ -193,7 +193,7 @@ class RobertaTokenizer(GPT2Tokenizer):
) -> List[int]:
"""
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
special tokens using the tokenizer ``prepare_for_model`` method.
Args:
token_ids_0 (:obj:`List[int]`):
......
@@ -820,7 +820,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
) -> List[int]:
"""
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
special tokens using the tokenizer ``prepare_for_model`` method.
Args:
token_ids_0: list of ids (must not contain special tokens)
......
@@ -1583,6 +1583,42 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
If the sequences are provided as list of strings (pretokenized), you must set `is_pretokenized=True`
(to lift the ambiguity with a batch of sequences)
"""
# Input type checking for clearer error
assert isinstance(text, str) or (
isinstance(text, (list, tuple))
and (
len(text) == 0
or (
isinstance(text[0], str)
or (isinstance(text[0], (list, tuple)) and (len(text[0]) == 0 or isinstance(text[0][0], str)))
)
)
), (
"text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) "
"or `List[List[str]]` (batch of pretokenized examples)."
)
assert (
text_pair is None
or isinstance(text_pair, str)
or (
isinstance(text_pair, (list, tuple))
and (
len(text_pair) == 0
or (
isinstance(text_pair[0], str)
or (
isinstance(text_pair[0], (list, tuple))
and (len(text_pair[0]) == 0 or isinstance(text_pair[0][0], str))
)
)
)
)
), (
"text_pair input must of type `str` (single example), `List[str]` (batch or single pretokenized example) "
"or `List[List[str]]` (batch of pretokenized examples)."
)
is_batched = bool(
(not is_pretokenized and isinstance(text, (list, tuple)))
or (is_pretokenized and isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple)))
......
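The new assertion accepts exactly three input shapes and rejects everything else early. A sketch of each accepted form, using an illustrative checkpoint name:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint

# str: a single example.
enc = tokenizer("A single sentence.")

# List[str]: a batch of examples (is_pretokenized defaults to False).
enc = tokenizer(["First sentence.", "Second sentence."])

# List[List[str]]: a batch of pretokenized examples; is_pretokenized=True lifts the
# ambiguity with a batch of plain strings, as the docstring above explains.
enc = tokenizer([["A", "single", "pretokenized", "sentence"]], is_pretokenized=True)

# Anything else (ints, tensors, dicts, ...) now fails fast with the message above
# instead of raising an opaque error deeper in the encoding code.
```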
@@ -882,7 +882,7 @@ class XLMTokenizer(PreTrainedTokenizer):
) -> List[int]:
"""
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
special tokens using the tokenizer ``prepare_for_model`` method.
Args:
token_ids_0 (:obj:`List[int]`):
......
@@ -206,7 +206,7 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
) -> List[int]:
"""
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
special tokens using the tokenizer ``prepare_for_model`` method.
Args:
token_ids_0 (:obj:`List[int]`):
......
@@ -267,7 +267,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
) -> List[int]:
"""
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
special tokens using the tokenizer ``prepare_for_model`` method.
Args:
token_ids_0 (:obj:`List[int]`):
......
@@ -171,7 +171,7 @@ class XxxTokenizer(PreTrainedTokenizer):
def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
"""
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
special tokens using the tokenizer ``prepare_for_model`` method.
Args:
token_ids_0: list of ids (must not contain special tokens)
......
@@ -626,9 +626,9 @@ class BartModelIntegrationTests(unittest.TestCase):
PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
EXPECTED_SUMMARY = "California's largest power company has begun shutting off power to tens of thousands of homes and businesses in the state."
dct = tok.batch_encode_plus([PGE_ARTICLE], max_length=1024, pad_to_max_length=True, return_tensors="pt",).to(
torch_device
)
dct = tok.batch_encode_plus(
[PGE_ARTICLE], max_length=1024, padding="max_length", truncation=True, return_tensors="pt",
).to(torch_device)
hypotheses_batch = model.generate(
input_ids=dct["input_ids"],
@@ -672,7 +672,8 @@ class BartModelIntegrationTests(unittest.TestCase):
dct = tok.batch_encode_plus(
[FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY],
max_length=1024,
pad_to_max_length=True,
padding="max_length",
truncation=True,
return_tensors="pt",
)
......
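The test now spells out what `pad_to_max_length=True` used to imply. A sketch of the two most common combinations of the new arguments, with an illustrative checkpoint:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")  # illustrative checkpoint

texts = ["a short example", "a somewhat longer example sentence"]

# Pad every example to exactly max_length and truncate longer ones,
# mirroring the old pad_to_max_length=True behaviour used in this test.
fixed = tokenizer(texts, max_length=1024, padding="max_length", truncation=True, return_tensors="pt")

# Pad only up to the longest example in the batch (dynamic padding).
dynamic = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
```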
@@ -375,10 +375,11 @@ class T5ModelIntegrationTests(unittest.TestCase):
summarization_config = task_specific_config.get("summarization", {})
model.config.update(summarization_config)
dct = tok.batch_encode_plus(
dct = tok(
[model.config.prefix + x for x in [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY]],
max_length=512,
pad_to_max_length=True,
padding="max_length",
truncation=True,
return_tensors="pt",
)
self.assertEqual(512, dct["input_ids"].shape[1])
......
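The T5 integration test drops `batch_encode_plus` entirely in favour of calling the tokenizer. A standalone sketch of the same pattern; the checkpoint name and the "summarize: " prefix are assumptions taken from T5's usual task config, not from this hunk:

```python
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-base")             # assumed checkpoint
model = T5ForConditionalGeneration.from_pretrained("t5-base")  # assumed checkpoint

# T5 is text-to-text: the task prefix is prepended to every input before encoding.
articles = ["The quick brown fox jumped over the lazy dog. " * 20]
batch = tokenizer(
    ["summarize: " + x for x in articles],
    max_length=512,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)
summary_ids = model.generate(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))
```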
@@ -276,10 +276,11 @@ class TFT5ModelIntegrationTests(unittest.TestCase):
summarization_config = task_specific_config.get("summarization", {})
model.config.update(summarization_config)
dct = tok.batch_encode_plus(
dct = tok(
[model.config.prefix + x for x in [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY]],
max_length=512,
pad_to_max_length=True,
padding="max_length",
truncation=True,
return_tensors="tf",
)
self.assertEqual(512, dct["input_ids"].shape[1])
......
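The TensorFlow test is the same change; only `return_tensors="tf"` differs. A minimal sketch of what that returns (checkpoint assumed):

```python
import tensorflow as tf
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-base")  # assumed checkpoint

batch = tokenizer(
    ["summarize: a short test input"], max_length=512, padding="max_length", truncation=True, return_tensors="tf",
)

# The same call now yields tf.Tensor values instead of torch tensors,
# padded out to the requested max_length.
assert isinstance(batch["input_ids"], tf.Tensor)
assert batch["input_ids"].shape[1] == 512
```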