"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "7183cded4e10fd35ee1737d738c4a657db28ee1f"
Unverified commit b4727a12, authored by Omar Salman and committed by GitHub
Browse files

Fix conflicting key in init kwargs in PreTrainedTokenizerBase (#31233)

* Fix conflicting key in init kwargs in PreTrainedTokenizerBase

* Update code to check for callable key in save_pretrained

* Apply PR suggestions

* Invoke CI

* Updates based on PR suggestion
parent db8c7cae
...@@ -1569,6 +1569,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): ...@@ -1569,6 +1569,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
def __init__(self, **kwargs): def __init__(self, **kwargs):
# inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``) # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``)
self.init_inputs = () self.init_inputs = ()
for key in kwargs:
if hasattr(self, key) and callable(getattr(self, key)):
raise AttributeError(f"{key} conflicts with the method {key} in {self.__class__.__name__}")
self.init_kwargs = copy.deepcopy(kwargs) self.init_kwargs = copy.deepcopy(kwargs)
self.name_or_path = kwargs.pop("name_or_path", "") self.name_or_path = kwargs.pop("name_or_path", "")
self._processor_class = kwargs.pop("processor_class", None) self._processor_class = kwargs.pop("processor_class", None)
......
...@@ -4408,3 +4408,11 @@ class TokenizerTesterMixin: ...@@ -4408,3 +4408,11 @@ class TokenizerTesterMixin:
replace_additional_special_tokens=False, replace_additional_special_tokens=False,
) )
self.assertEqual(tokenizer_2.additional_special_tokens, ["<other>", "<another>", "<tok>"]) self.assertEqual(tokenizer_2.additional_special_tokens, ["<other>", "<another>", "<tok>"])
def test_tokenizer_initialization_with_conflicting_key(self):
    """Check that init kwargs which clash with tokenizer methods are rejected.

    Builds a tokenizer (the Rust one when ``self.test_rust_tokenizer`` is set,
    otherwise the Python one) with a keyword argument named after an existing
    method and expects an ``AttributeError`` whose message contains
    "conflicts with the method".
    """
    get_tokenizer_func = self.get_rust_tokenizer if self.test_rust_tokenizer else self.get_tokenizer

    # Use assertRaisesRegex rather than assertRaises(..., msg=...): `msg` is only
    # the text shown when the assertion fails and is never compared against the
    # raised exception, so an unrelated AttributeError would also have passed.
    with self.assertRaisesRegex(AttributeError, "conflicts with the method"):
        get_tokenizer_func(add_special_tokens=True)

    with self.assertRaisesRegex(AttributeError, "conflicts with the method"):
        get_tokenizer_func(get_vocab=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment