Unverified Commit 6ee1474b authored by Vít Novotný and committed by GitHub

Accumulate tokens into batches in `PreTrainedTokenizerBase.add_tokens()` (#17119)

* Accumulate tokens into batches in PreTrainedTokenizerBase.add_tokens()

For tokenizers with a small number of special tokens, or with special tokens that have
consecutive token IDs, this reduces the time complexity of creating the trie from
quadratic to linear; see also #16936 and the sketch below.

* Extend explanation of batching added tokens
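
To make the complexity claim concrete, here is a minimal back-of-the-envelope sketch (not part of the commit; trie_insertions_one_by_one and trie_insertions_batched are hypothetical helpers). It assumes that every add_tokens() call rebuilds the no-split-token trie over all tokens added so far, which is what the in-line comment added in the diff below says:

def trie_insertions_one_by_one(n_tokens: int) -> int:
    # The k-th add_tokens() call rebuilds a trie over k tokens, so the total
    # work is 1 + 2 + ... + n = n * (n + 1) / 2 insertions, i.e. O(n^2).
    return sum(range(1, n_tokens + 1))


def trie_insertions_batched(n_tokens: int, n_batches: int = 1) -> int:
    # Each batch triggers one rebuild over at most n tokens, so a small,
    # fixed number of batches keeps the total work at O(n) insertions.
    return n_batches * n_tokens


print(trie_insertions_one_by_one(1000))  # 500500
print(trie_insertions_batched(1000))     # 1000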
parent 52e7c929
...
@@ -1964,24 +1964,38 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
             # Sort added tokens by index
             added_tok_encoder_sorted = list(sorted(added_tok_encoder.items(), key=lambda x: x[1]))
+            # Accumulate added tokens into batches of special/non-special tokens, because calling add_tokens() for
+            # individual tokens would repeatedly rebuild a trie, which can be slow.
+            is_last_special = None
+            tokens = []
             for token, index in added_tok_encoder_sorted:
-                if has_tokenizer_file and index != len(tokenizer) and tokenizer.convert_tokens_to_ids(token) != index:
+                current_index = len(tokenizer) + len(tokens)
+                if has_tokenizer_file and index != current_index and tokenizer.convert_tokens_to_ids(token) != index:
                     # Tokenizer fast: added token needs to either be in the vocabulary with the proper index or the
                     # index is the current length of the tokenizer (not in vocabulary)
                     raise ValueError(
                         f"Wrong index found for {token}: should be {tokenizer.convert_tokens_to_ids(token)} but found "
                         f"{index}."
                     )
-                elif not has_tokenizer_file and index != len(tokenizer):
+                elif not has_tokenizer_file and index != current_index:
                     # Tokenizer slow: added token cannot already be in the vocabulary so its index needs to be the
                     # current length of the tokenizer.
                     raise ValueError(
                         f"Non-consecutive added token '{token}' found. "
-                        f"Should have index {len(tokenizer)} but has index {index} in saved vocabulary."
+                        f"Should have index {current_index} but has index {index} in saved vocabulary."
                     )
-                # Safe to call on a tokenizer fast even if token already there.
-                tokenizer.add_tokens(token, special_tokens=bool(token in special_tokens))
+                is_special = bool(token in special_tokens)
+                if is_last_special is None or is_last_special == is_special:
+                    tokens.append(token)
+                else:
+                    tokenizer.add_tokens(tokens, special_tokens=is_last_special)
+                    tokens = [token]
+                is_last_special = is_special
+            if tokens:
+                tokenizer.add_tokens(tokens, special_tokens=is_last_special)

         # Check all our special tokens are registered as "no split" token (we don't cut them) and are in the vocab
         added_tokens = tokenizer.sanitize_special_tokens()
...
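
The same batching idea can be exercised from user code. The following is a rough standalone sketch, not code from the repository: it groups a toy list of already-sorted added tokens into runs with the same special/non-special status and adds each run with a single add_tokens() call, so the trie is rebuilt once per run instead of once per token. The token names and the special_tokens set are made up, and loading bert-base-uncased assumes the model files are available:

from itertools import groupby

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=False)

# Hypothetical added tokens, already sorted by their saved index; pretend the
# middle one was registered as a special token.
added_tokens_sorted = ["<ent1>", "<ent2>", "[NEW_SPECIAL]", "<ent3>", "<ent4>"]
special_tokens = {"[NEW_SPECIAL]"}

for is_special, run in groupby(added_tokens_sorted, key=lambda t: t in special_tokens):
    batch = list(run)
    # One add_tokens() call (and hence one trie rebuild) per run of equal status.
    tokenizer.add_tokens(batch, special_tokens=is_special)
    print(f"added {batch} (special={is_special})")
# added ['<ent1>', '<ent2>'] (special=False)
# added ['[NEW_SPECIAL]'] (special=True)
# added ['<ent3>', '<ent4>'] (special=False)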