Commit 6ce6d62b (unverified), authored by Yih-Dar, committed by GitHub
Browse files

Explicit arguments in `from_pretrained` (#24306)



* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 127e81c2
......@@ -118,6 +118,8 @@ class XCLIPTextConfig(PretrainedConfig):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the text config dict if we are loading from XCLIPConfig
......@@ -243,6 +245,8 @@ class XCLIPVisionConfig(PretrainedConfig):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the vision config dict if we are loading from XCLIPConfig
......
......@@ -17,7 +17,9 @@
"""
import os
import warnings
from pathlib import Path
from typing import Optional, Union
from .dynamic_module_utils import custom_object_save
from .tokenization_utils_base import PreTrainedTokenizerBase
......@@ -151,7 +153,16 @@ class ProcessorMixin(PushToHubMixin):
)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
**kwargs,
):
r"""
Instantiate a processor associated with a pretrained model.
......@@ -181,6 +192,26 @@ class ProcessorMixin(PushToHubMixin):
[`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] and
[`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`].
"""
kwargs["cache_dir"] = cache_dir
kwargs["force_download"] = force_download
kwargs["local_files_only"] = local_files_only
kwargs["revision"] = revision
use_auth_token = kwargs.pop("use_auth_token", None)
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
)
if token is not None:
raise ValueError(
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
)
token = use_auth_token
if token is not None:
# change to `token` in a follow-up PR
kwargs["use_auth_token"] = token
args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs)
return cls(*args)
......
......@@ -1615,7 +1615,17 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
raise NotImplementedError()
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], *init_inputs, **kwargs):
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
*init_inputs,
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
**kwargs,
):
r"""
Instantiate a [`~tokenization_utils_base.PreTrainedTokenizerBase`] (or a derived class) from a predefined
tokenizer.
......@@ -1645,7 +1655,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
use_auth_token (`str` or *bool*, *optional*):
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `huggingface-cli login` (stored in `~/.huggingface`).
local_files_only (`bool`, *optional*, defaults to `False`):
......@@ -1692,18 +1702,29 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
# Otherwise use tokenizer.add_special_tokens({'unk_token': '<unk>'}) instead)
assert tokenizer.unk_token == "<unk>"
```"""
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
commit_hash = kwargs.pop("_commit_hash", None)
use_auth_token = kwargs.pop("use_auth_token", None)
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
)
if token is not None:
raise ValueError(
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
)
token = use_auth_token
if token is not None:
# change to `token` in a follow-up PR
kwargs["use_auth_token"] = token
user_agent = {"file_type": "tokenizer", "from_auto_class": from_auto_class, "is_fast": "Fast" in cls.__name__}
if from_pipeline is not None:
user_agent["using_pipeline"] = from_pipeline
......@@ -1752,7 +1773,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
use_auth_token=use_auth_token,
use_auth_token=token,
revision=revision,
local_files_only=local_files_only,
subfolder=subfolder,
......@@ -1789,7 +1810,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
use_auth_token=token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
......@@ -1827,7 +1848,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
pretrained_model_name_or_path,
init_configuration,
*init_inputs,
use_auth_token=use_auth_token,
use_auth_token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
_commit_hash=commit_hash,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment