Unverified Commit 293991d4 authored by Yih-Dar, committed by GitHub

Make `add_special_tokens` more clear (#20424)



* make add_special_tokens more clear
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent d0c1ded5
@@ -841,7 +841,9 @@ class SpecialTokensMixin:
         """
         return self.add_tokens(self.all_special_tokens_extended, special_tokens=True)
 
-    def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, AddedToken]]) -> int:
+    def add_special_tokens(
+        self, special_tokens_dict: Dict[str, Union[str, AddedToken]], replace_additional_special_tokens=True
+    ) -> int:
         """
         Add a dictionary of special tokens (eos, pad, cls, etc.) to the encoder and link them to class attributes. If
         special tokens are NOT in the vocabulary, they are added to it (indexed starting from the last index of the
@@ -869,6 +871,11 @@ class SpecialTokensMixin:
                 Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer
                 assign the index of the `unk_token` to them).
+            replace_additional_special_tokens (`bool`, *optional*, defaults to `True`):
+                If `True`, the existing list of additional special tokens will be replaced by the one specified in
+                `special_tokens_dict`. Otherwise, `self._additional_special_tokens` is updated. In the former case, the
+                tokens will NOT be removed from the tokenizer's full vocabulary - they are only being flagged as
+                non-special tokens.
 
         Returns:
             `int`: Number of tokens added to the vocabulary.
@@ -898,17 +905,32 @@ class SpecialTokensMixin:
             if self.verbose:
                 logger.info(f"Assigning {value} to the {key} key of the tokenizer")
-            setattr(self, key, value)
             if key == "additional_special_tokens":
                 assert isinstance(value, (list, tuple)) and all(
                     isinstance(t, (str, AddedToken)) for t in value
                 ), f"Tokens {value} for key {key} should all be str or AddedToken instances"
+                if replace_additional_special_tokens:
+                    setattr(self, key, value)
+                else:
+                    # This is a copy of `self._additional_special_tokens`
+                    additional_special_tokens = getattr(self, key)
+                    additional_special_tokens_set = set(additional_special_tokens)
+                    to_add = []
+                    for token in value:
+                        if str(token) not in additional_special_tokens_set and str(token) not in to_add:
+                            to_add.append(token)
+                    # update the property
+                    additional_special_tokens.extend(to_add)
+                    self.additional_special_tokens = additional_special_tokens
                 added_tokens += self.add_tokens(value, special_tokens=True)
             else:
                 assert isinstance(
                     value, (str, AddedToken)
                 ), f"Token {value} for key {key} should be a str or an AddedToken instance"
+                setattr(self, key, value)
                 added_tokens += self.add_tokens([value], special_tokens=True)
 
         return added_tokens
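For context, here is a minimal usage sketch (not part of the commit) showing the effect of the new flag. Only the `add_special_tokens` signature and the replace-vs-update behaviour come from the diff above; the checkpoint name and the placeholder tokens are illustrative assumptions.

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint

    # Default behaviour (replace_additional_special_tokens=True): each call replaces
    # the previous list of additional special tokens.
    tokenizer.add_special_tokens({"additional_special_tokens": ["<tok_a>"]})
    tokenizer.add_special_tokens({"additional_special_tokens": ["<tok_b>"]})
    print(tokenizer.additional_special_tokens)  # ['<tok_b>']

    # Per the docstring, "<tok_a>" is only unflagged as special; it still has an
    # id in the tokenizer's full vocabulary.
    print(tokenizer.convert_tokens_to_ids("<tok_a>"))

    # With replace_additional_special_tokens=False, the existing list is extended
    # instead (tokens already present are skipped by the dedup loop in the diff).
    tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<tok_c>"]}, replace_additional_special_tokens=False
    )
    print(tokenizer.additional_special_tokens)  # ['<tok_b>', '<tok_c>']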