Unverified Commit 6ce6d62b authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Explicit arguments in `from_pretrained` (#24306)



* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 127e81c2
...@@ -467,7 +467,43 @@ class PretrainedConfig(PushToHubMixin): ...@@ -467,7 +467,43 @@ class PretrainedConfig(PushToHubMixin):
) )
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": def _set_token_in_kwargs(self, kwargs, token=None):
"""Temporary method to deal with `token` and `use_auth_token`.
This method is to avoid apply the same changes in all model config classes that overwrite `from_pretrained`.
Need to clean up `use_auth_token` in a follow PR.
"""
# Some model config classes like CLIP define their own `from_pretrained` without the new argument `token` yet.
if token is None:
token = kwargs.pop("token", None)
use_auth_token = kwargs.pop("use_auth_token", None)
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
)
if token is not None:
raise ValueError(
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
)
token = use_auth_token
if token is not None:
# change to `token` in a follow-up PR
kwargs["use_auth_token"] = token
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
**kwargs,
) -> "PretrainedConfig":
r""" r"""
Instantiate a [`PretrainedConfig`] (or a derived class) from a pretrained model configuration. Instantiate a [`PretrainedConfig`] (or a derived class) from a pretrained model configuration.
...@@ -493,7 +529,7 @@ class PretrainedConfig(PushToHubMixin): ...@@ -493,7 +529,7 @@ class PretrainedConfig(PushToHubMixin):
proxies (`Dict[str, str]`, *optional*): proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
use_auth_token (`str` or `bool`, *optional*): token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`): revision (`str`, *optional*, defaults to `"main"`):
...@@ -544,6 +580,13 @@ class PretrainedConfig(PushToHubMixin): ...@@ -544,6 +580,13 @@ class PretrainedConfig(PushToHubMixin):
assert config.output_attentions == True assert config.output_attentions == True
assert unused_kwargs == {"foo": False} assert unused_kwargs == {"foo": False}
```""" ```"""
kwargs["cache_dir"] = cache_dir
kwargs["force_download"] = force_download
kwargs["local_files_only"] = local_files_only
kwargs["revision"] = revision
cls._set_token_in_kwargs(kwargs, token)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning( logger.warning(
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
import copy import copy
import json import json
import os import os
import warnings
from collections import UserDict from collections import UserDict
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
...@@ -255,8 +256,15 @@ class FeatureExtractionMixin(PushToHubMixin): ...@@ -255,8 +256,15 @@ class FeatureExtractionMixin(PushToHubMixin):
@classmethod @classmethod
def from_pretrained( def from_pretrained(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs cls,
) -> PreTrainedFeatureExtractor: pretrained_model_name_or_path: Union[str, os.PathLike],
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
**kwargs,
):
r""" r"""
Instantiate a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a feature extractor, *e.g.* a Instantiate a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a feature extractor, *e.g.* a
derived class of [`SequenceFeatureExtractor`]. derived class of [`SequenceFeatureExtractor`].
...@@ -285,7 +293,7 @@ class FeatureExtractionMixin(PushToHubMixin): ...@@ -285,7 +293,7 @@ class FeatureExtractionMixin(PushToHubMixin):
proxies (`Dict[str, str]`, *optional*): proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
use_auth_token (`str` or `bool`, *optional*): token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`): revision (`str`, *optional*, defaults to `"main"`):
...@@ -335,6 +343,26 @@ class FeatureExtractionMixin(PushToHubMixin): ...@@ -335,6 +343,26 @@ class FeatureExtractionMixin(PushToHubMixin):
assert feature_extractor.return_attention_mask is False assert feature_extractor.return_attention_mask is False
assert unused_kwargs == {"foo": False} assert unused_kwargs == {"foo": False}
```""" ```"""
kwargs["cache_dir"] = cache_dir
kwargs["force_download"] = force_download
kwargs["local_files_only"] = local_files_only
kwargs["revision"] = revision
use_auth_token = kwargs.pop("use_auth_token", None)
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
)
if token is not None:
raise ValueError(
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
)
token = use_auth_token
if token is not None:
# change to `token` in a follow-up PR
kwargs["use_auth_token"] = token
feature_extractor_dict, kwargs = cls.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs) feature_extractor_dict, kwargs = cls.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
return cls.from_dict(feature_extractor_dict, **kwargs) return cls.from_dict(feature_extractor_dict, **kwargs)
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import copy import copy
import json import json
import os import os
import warnings
from typing import Any, Dict, Optional, Union from typing import Any, Dict, Optional, Union
from .. import __version__ from .. import __version__
...@@ -382,6 +383,11 @@ class GenerationConfig(PushToHubMixin): ...@@ -382,6 +383,11 @@ class GenerationConfig(PushToHubMixin):
cls, cls,
pretrained_model_name: Union[str, os.PathLike], pretrained_model_name: Union[str, os.PathLike],
config_file_name: Optional[Union[str, os.PathLike]] = None, config_file_name: Optional[Union[str, os.PathLike]] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
**kwargs, **kwargs,
) -> "GenerationConfig": ) -> "GenerationConfig":
r""" r"""
...@@ -410,7 +416,7 @@ class GenerationConfig(PushToHubMixin): ...@@ -410,7 +416,7 @@ class GenerationConfig(PushToHubMixin):
proxies (`Dict[str, str]`, *optional*): proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
use_auth_token (`str` or `bool`, *optional*): token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`): revision (`str`, *optional*, defaults to `"main"`):
...@@ -470,18 +476,24 @@ class GenerationConfig(PushToHubMixin): ...@@ -470,18 +476,24 @@ class GenerationConfig(PushToHubMixin):
```""" ```"""
config_file_name = config_file_name if config_file_name is not None else GENERATION_CONFIG_NAME config_file_name = config_file_name if config_file_name is not None else GENERATION_CONFIG_NAME
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
use_auth_token = kwargs.pop("use_auth_token", None) use_auth_token = kwargs.pop("use_auth_token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", "") subfolder = kwargs.pop("subfolder", "")
from_pipeline = kwargs.pop("_from_pipeline", None) from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False) from_auto_class = kwargs.pop("_from_auto", False)
commit_hash = kwargs.pop("_commit_hash", None) commit_hash = kwargs.pop("_commit_hash", None)
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
)
if token is not None:
raise ValueError(
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
)
token = use_auth_token
user_agent = {"file_type": "config", "from_auto_class": from_auto_class} user_agent = {"file_type": "config", "from_auto_class": from_auto_class}
if from_pipeline is not None: if from_pipeline is not None:
user_agent["using_pipeline"] = from_pipeline user_agent["using_pipeline"] = from_pipeline
...@@ -509,7 +521,7 @@ class GenerationConfig(PushToHubMixin): ...@@ -509,7 +521,7 @@ class GenerationConfig(PushToHubMixin):
proxies=proxies, proxies=proxies,
resume_download=resume_download, resume_download=resume_download,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, use_auth_token=token,
user_agent=user_agent, user_agent=user_agent,
revision=revision, revision=revision,
subfolder=subfolder, subfolder=subfolder,
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import copy import copy
import json import json
import os import os
import warnings
from typing import Any, Dict, Iterable, Optional, Tuple, Union from typing import Any, Dict, Iterable, Optional, Tuple, Union
import numpy as np import numpy as np
...@@ -81,7 +82,16 @@ class ImageProcessingMixin(PushToHubMixin): ...@@ -81,7 +82,16 @@ class ImageProcessingMixin(PushToHubMixin):
self._processor_class = processor_class self._processor_class = processor_class
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs): def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
**kwargs,
):
r""" r"""
Instantiate a type of [`~image_processing_utils.ImageProcessingMixin`] from an image processor. Instantiate a type of [`~image_processing_utils.ImageProcessingMixin`] from an image processor.
...@@ -109,7 +119,7 @@ class ImageProcessingMixin(PushToHubMixin): ...@@ -109,7 +119,7 @@ class ImageProcessingMixin(PushToHubMixin):
proxies (`Dict[str, str]`, *optional*): proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
use_auth_token (`str` or `bool`, *optional*): token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`): revision (`str`, *optional*, defaults to `"main"`):
...@@ -162,6 +172,26 @@ class ImageProcessingMixin(PushToHubMixin): ...@@ -162,6 +172,26 @@ class ImageProcessingMixin(PushToHubMixin):
assert image_processor.do_normalize is False assert image_processor.do_normalize is False
assert unused_kwargs == {"foo": False} assert unused_kwargs == {"foo": False}
```""" ```"""
kwargs["cache_dir"] = cache_dir
kwargs["force_download"] = force_download
kwargs["local_files_only"] = local_files_only
kwargs["revision"] = revision
use_auth_token = kwargs.pop("use_auth_token", None)
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
)
if token is not None:
raise ValueError(
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
)
token = use_auth_token
if token is not None:
# change to `token` in a follow-up PR
kwargs["use_auth_token"] = token
image_processor_dict, kwargs = cls.get_image_processor_dict(pretrained_model_name_or_path, **kwargs) image_processor_dict, kwargs = cls.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
return cls.from_dict(image_processor_dict, **kwargs) return cls.from_dict(image_processor_dict, **kwargs)
......
...@@ -18,9 +18,10 @@ import gc ...@@ -18,9 +18,10 @@ import gc
import json import json
import os import os
import re import re
import warnings
from functools import partial from functools import partial
from pickle import UnpicklingError from pickle import UnpicklingError
from typing import Any, Dict, Set, Tuple, Union from typing import Any, Dict, Optional, Set, Tuple, Union
import flax.linen as nn import flax.linen as nn
import jax import jax
...@@ -485,6 +486,13 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): ...@@ -485,6 +486,13 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
pretrained_model_name_or_path: Union[str, os.PathLike], pretrained_model_name_or_path: Union[str, os.PathLike],
dtype: jnp.dtype = jnp.float32, dtype: jnp.dtype = jnp.float32,
*model_args, *model_args,
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
ignore_mismatched_sizes: bool = False,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
**kwargs, **kwargs,
): ):
r""" r"""
...@@ -558,7 +566,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): ...@@ -558,7 +566,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only(`bool`, *optional*, defaults to `False`): local_files_only(`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model). Whether or not to only look at local files (i.e., do not try to download the model).
use_auth_token (`str` or `bool`, *optional*): token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`): revision (`str`, *optional*, defaults to `"main"`):
...@@ -603,16 +611,10 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): ...@@ -603,16 +611,10 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
>>> config = BertConfig.from_json_file("./pt_model/config.json") >>> config = BertConfig.from_json_file("./pt_model/config.json")
>>> model = FlaxBertModel.from_pretrained("./pt_model/pytorch_model.bin", from_pt=True, config=config) >>> model = FlaxBertModel.from_pretrained("./pt_model/pytorch_model.bin", from_pt=True, config=config)
```""" ```"""
config = kwargs.pop("config", None)
cache_dir = kwargs.pop("cache_dir", None)
from_pt = kwargs.pop("from_pt", False) from_pt = kwargs.pop("from_pt", False)
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None) use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
trust_remote_code = kwargs.pop("trust_remote_code", None) trust_remote_code = kwargs.pop("trust_remote_code", None)
from_pipeline = kwargs.pop("_from_pipeline", None) from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False) from_auto_class = kwargs.pop("_from_auto", False)
...@@ -620,6 +622,16 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): ...@@ -620,6 +622,16 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
subfolder = kwargs.pop("subfolder", "") subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None) commit_hash = kwargs.pop("_commit_hash", None)
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
)
if token is not None:
raise ValueError(
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
)
token = use_auth_token
if trust_remote_code is True: if trust_remote_code is True:
logger.warning( logger.warning(
"The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is" "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
...@@ -645,7 +657,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): ...@@ -645,7 +657,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
subfolder=subfolder, subfolder=subfolder,
_from_auto=from_auto_class, _from_auto=from_auto_class,
...@@ -715,7 +727,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): ...@@ -715,7 +727,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
"proxies": proxies, "proxies": proxies,
"resume_download": resume_download, "resume_download": resume_download,
"local_files_only": local_files_only, "local_files_only": local_files_only,
"use_auth_token": use_auth_token, "use_auth_token": token,
"user_agent": user_agent, "user_agent": user_agent,
"revision": revision, "revision": revision,
"subfolder": subfolder, "subfolder": subfolder,
...@@ -746,7 +758,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): ...@@ -746,7 +758,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
has_file_kwargs = { has_file_kwargs = {
"revision": revision, "revision": revision,
"proxies": proxies, "proxies": proxies,
"use_auth_token": use_auth_token, "use_auth_token": token,
} }
if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs): if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
raise EnvironmentError( raise EnvironmentError(
...@@ -797,7 +809,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): ...@@ -797,7 +809,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
proxies=proxies, proxies=proxies,
resume_download=resume_download, resume_download=resume_download,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, use_auth_token=token,
user_agent=user_agent, user_agent=user_agent,
revision=revision, revision=revision,
subfolder=subfolder, subfolder=subfolder,
...@@ -986,7 +998,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): ...@@ -986,7 +998,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
subfolder=subfolder, subfolder=subfolder,
_from_auto=from_auto_class, _from_auto=from_auto_class,
......
...@@ -2507,7 +2507,19 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2507,7 +2507,19 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
) )
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
*model_args,
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
ignore_mismatched_sizes: bool = False,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
**kwargs,
):
r""" r"""
Instantiate a pretrained TF 2.0 model from a pre-trained model configuration. Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.
...@@ -2573,7 +2585,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2573,7 +2585,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
dictionary containing missing keys, unexpected keys and error messages. dictionary containing missing keys, unexpected keys and error messages.
local_files_only(`bool`, *optional*, defaults to `False`): local_files_only(`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (e.g., not try doanloading the model). Whether or not to only look at local files (e.g., not try doanloading the model).
use_auth_token (`str` or `bool`, *optional*): token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`): revision (`str`, *optional*, defaults to `"main"`):
...@@ -2629,17 +2641,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2629,17 +2641,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
>>> config = BertConfig.from_json_file("./pt_model/my_pt_model_config.json") >>> config = BertConfig.from_json_file("./pt_model/my_pt_model_config.json")
>>> model = TFBertModel.from_pretrained("./pt_model/my_pytorch_model.bin", from_pt=True, config=config) >>> model = TFBertModel.from_pretrained("./pt_model/my_pytorch_model.bin", from_pt=True, config=config)
```""" ```"""
config = kwargs.pop("config", None)
cache_dir = kwargs.pop("cache_dir", None)
from_pt = kwargs.pop("from_pt", False) from_pt = kwargs.pop("from_pt", False)
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
output_loading_info = kwargs.pop("output_loading_info", False) output_loading_info = kwargs.pop("output_loading_info", False)
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None) use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
trust_remote_code = kwargs.pop("trust_remote_code", None) trust_remote_code = kwargs.pop("trust_remote_code", None)
_ = kwargs.pop("mirror", None) _ = kwargs.pop("mirror", None)
load_weight_prefix = kwargs.pop("load_weight_prefix", None) load_weight_prefix = kwargs.pop("load_weight_prefix", None)
...@@ -2649,6 +2655,16 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2649,6 +2655,16 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
commit_hash = kwargs.pop("_commit_hash", None) commit_hash = kwargs.pop("_commit_hash", None)
tf_to_pt_weight_rename = kwargs.pop("tf_to_pt_weight_rename", None) tf_to_pt_weight_rename = kwargs.pop("tf_to_pt_weight_rename", None)
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
)
if token is not None:
raise ValueError(
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
)
token = use_auth_token
if trust_remote_code is True: if trust_remote_code is True:
logger.warning( logger.warning(
"The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is" "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
...@@ -2674,7 +2690,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2674,7 +2690,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=use_auth_token,
revision=revision, revision=revision,
_from_auto=from_auto_class, _from_auto=from_auto_class,
_from_pipeline=from_pipeline, _from_pipeline=from_pipeline,
...@@ -2761,7 +2777,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2761,7 +2777,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
"proxies": proxies, "proxies": proxies,
"resume_download": resume_download, "resume_download": resume_download,
"local_files_only": local_files_only, "local_files_only": local_files_only,
"use_auth_token": use_auth_token, "use_auth_token": token,
"user_agent": user_agent, "user_agent": user_agent,
"revision": revision, "revision": revision,
"subfolder": subfolder, "subfolder": subfolder,
...@@ -2808,7 +2824,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2808,7 +2824,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
has_file_kwargs = { has_file_kwargs = {
"revision": revision, "revision": revision,
"proxies": proxies, "proxies": proxies,
"use_auth_token": use_auth_token, "use_auth_token": token,
} }
if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs): if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
raise EnvironmentError( raise EnvironmentError(
...@@ -2855,7 +2871,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2855,7 +2871,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
proxies=proxies, proxies=proxies,
resume_download=resume_download, resume_download=resume_download,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, use_auth_token=token,
user_agent=user_agent, user_agent=user_agent,
revision=revision, revision=revision,
_commit_hash=commit_hash, _commit_hash=commit_hash,
...@@ -3011,7 +3027,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -3011,7 +3027,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
subfolder=subfolder, subfolder=subfolder,
_from_auto=from_auto_class, _from_auto=from_auto_class,
......
...@@ -1914,7 +1914,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1914,7 +1914,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return super().float(*args) return super().float(*args)
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs): def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
*model_args,
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
ignore_mismatched_sizes: bool = False,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
use_safetensors: bool = None,
**kwargs,
):
r""" r"""
Instantiate a pretrained pytorch model from a pre-trained model configuration. Instantiate a pretrained pytorch model from a pre-trained model configuration.
...@@ -1995,7 +2008,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1995,7 +2008,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(`bool`, *optional*, defaults to `False`): local_files_only(`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model). Whether or not to only look at local files (i.e., do not try to download the model).
use_auth_token (`str` or `bool`, *optional*): token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`): revision (`str`, *optional*, defaults to `"main"`):
...@@ -2085,6 +2098,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2085,6 +2098,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
variant (`str`, *optional*): variant (`str`, *optional*):
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
ignored when using `from_tf` or `from_flax`. ignored when using `from_tf` or `from_flax`.
use_safetensors (`bool`, *optional*, defaults to `None`):
Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors`
is not installed, it will be set to `False`.
kwargs (remaining dictionary of keyword arguments, *optional*): kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
...@@ -2142,19 +2158,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2142,19 +2158,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
Currently, it can't handle deepspeed ZeRO stage 3 and ignores loading errors Currently, it can't handle deepspeed ZeRO stage 3 and ignores loading errors
""" """
config = kwargs.pop("config", None)
state_dict = kwargs.pop("state_dict", None) state_dict = kwargs.pop("state_dict", None)
cache_dir = kwargs.pop("cache_dir", None)
from_tf = kwargs.pop("from_tf", False) from_tf = kwargs.pop("from_tf", False)
from_flax = kwargs.pop("from_flax", False) from_flax = kwargs.pop("from_flax", False)
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
output_loading_info = kwargs.pop("output_loading_info", False) output_loading_info = kwargs.pop("output_loading_info", False)
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None) use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
trust_remote_code = kwargs.pop("trust_remote_code", None) trust_remote_code = kwargs.pop("trust_remote_code", None)
_ = kwargs.pop("mirror", None) _ = kwargs.pop("mirror", None)
from_pipeline = kwargs.pop("_from_pipeline", None) from_pipeline = kwargs.pop("_from_pipeline", None)
...@@ -2172,7 +2182,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2172,7 +2182,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
subfolder = kwargs.pop("subfolder", "") subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None) commit_hash = kwargs.pop("_commit_hash", None)
variant = kwargs.pop("variant", None) variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
)
if token is not None:
raise ValueError(
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
)
token = use_auth_token
if use_safetensors is None and not is_safetensors_available():
use_safetensors = False
if is_bitsandbytes_available(): if is_bitsandbytes_available():
is_8bit_serializable = version.parse(importlib_metadata.version("bitsandbytes")) > version.parse("0.37.2") is_8bit_serializable = version.parse(importlib_metadata.version("bitsandbytes")) > version.parse("0.37.2")
...@@ -2302,7 +2324,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2302,7 +2324,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
subfolder=subfolder, subfolder=subfolder,
_from_auto=from_auto_class, _from_auto=from_auto_class,
...@@ -2481,7 +2503,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2481,7 +2503,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"proxies": proxies, "proxies": proxies,
"resume_download": resume_download, "resume_download": resume_download,
"local_files_only": local_files_only, "local_files_only": local_files_only,
"use_auth_token": use_auth_token, "use_auth_token": token,
"user_agent": user_agent, "user_agent": user_agent,
"revision": revision, "revision": revision,
"subfolder": subfolder, "subfolder": subfolder,
...@@ -2526,7 +2548,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2526,7 +2548,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
has_file_kwargs = { has_file_kwargs = {
"revision": revision, "revision": revision,
"proxies": proxies, "proxies": proxies,
"use_auth_token": use_auth_token, "use_auth_token": token,
} }
if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs): if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs):
raise EnvironmentError( raise EnvironmentError(
...@@ -2587,7 +2609,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2587,7 +2609,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
proxies=proxies, proxies=proxies,
resume_download=resume_download, resume_download=resume_download,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, use_auth_token=token,
user_agent=user_agent, user_agent=user_agent,
revision=revision, revision=revision,
subfolder=subfolder, subfolder=subfolder,
...@@ -2910,7 +2932,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2910,7 +2932,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
subfolder=subfolder, subfolder=subfolder,
_from_auto=from_auto_class, _from_auto=from_auto_class,
......
...@@ -139,6 +139,8 @@ class AlignTextConfig(PretrainedConfig): ...@@ -139,6 +139,8 @@ class AlignTextConfig(PretrainedConfig):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the text config dict if we are loading from AlignConfig # get the text config dict if we are loading from AlignConfig
...@@ -276,6 +278,8 @@ class AlignVisionConfig(PretrainedConfig): ...@@ -276,6 +278,8 @@ class AlignVisionConfig(PretrainedConfig):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the vision config dict if we are loading from AlignConfig # get the vision config dict if we are loading from AlignConfig
......
...@@ -228,6 +228,8 @@ class AltCLIPVisionConfig(PretrainedConfig): ...@@ -228,6 +228,8 @@ class AltCLIPVisionConfig(PretrainedConfig):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the vision config dict if we are loading from AltCLIPConfig # get the vision config dict if we are loading from AltCLIPConfig
......
...@@ -161,6 +161,8 @@ class BlipTextConfig(PretrainedConfig): ...@@ -161,6 +161,8 @@ class BlipTextConfig(PretrainedConfig):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the text config dict if we are loading from BlipConfig # get the text config dict if we are loading from BlipConfig
...@@ -258,6 +260,8 @@ class BlipVisionConfig(PretrainedConfig): ...@@ -258,6 +260,8 @@ class BlipVisionConfig(PretrainedConfig):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the vision config dict if we are loading from BlipConfig # get the vision config dict if we are loading from BlipConfig
......
...@@ -113,6 +113,8 @@ class Blip2VisionConfig(PretrainedConfig): ...@@ -113,6 +113,8 @@ class Blip2VisionConfig(PretrainedConfig):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the vision config dict if we are loading from Blip2Config # get the vision config dict if we are loading from Blip2Config
...@@ -229,6 +231,8 @@ class Blip2QFormerConfig(PretrainedConfig): ...@@ -229,6 +231,8 @@ class Blip2QFormerConfig(PretrainedConfig):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the qformer config dict if we are loading from Blip2Config # get the qformer config dict if we are loading from Blip2Config
......
...@@ -144,6 +144,8 @@ class ChineseCLIPTextConfig(PretrainedConfig): ...@@ -144,6 +144,8 @@ class ChineseCLIPTextConfig(PretrainedConfig):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the vision config dict if we are loading from ChineseCLIPConfig # get the vision config dict if we are loading from ChineseCLIPConfig
...@@ -247,6 +249,8 @@ class ChineseCLIPVisionConfig(PretrainedConfig): ...@@ -247,6 +249,8 @@ class ChineseCLIPVisionConfig(PretrainedConfig):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the vision config dict if we are loading from ChineseCLIPConfig # get the vision config dict if we are loading from ChineseCLIPConfig
......
...@@ -144,6 +144,8 @@ class ClapTextConfig(PretrainedConfig): ...@@ -144,6 +144,8 @@ class ClapTextConfig(PretrainedConfig):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the text config dict if we are loading from ClapConfig # get the text config dict if we are loading from ClapConfig
...@@ -312,6 +314,8 @@ class ClapAudioConfig(PretrainedConfig): ...@@ -312,6 +314,8 @@ class ClapAudioConfig(PretrainedConfig):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the audio config dict if we are loading from ClapConfig # get the audio config dict if we are loading from ClapConfig
......
...@@ -127,6 +127,8 @@ class CLIPTextConfig(PretrainedConfig): ...@@ -127,6 +127,8 @@ class CLIPTextConfig(PretrainedConfig):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the text config dict if we are loading from CLIPConfig # get the text config dict if we are loading from CLIPConfig
...@@ -230,6 +232,8 @@ class CLIPVisionConfig(PretrainedConfig): ...@@ -230,6 +232,8 @@ class CLIPVisionConfig(PretrainedConfig):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the vision config dict if we are loading from CLIPConfig # get the vision config dict if we are loading from CLIPConfig
......
...@@ -117,6 +117,8 @@ class CLIPSegTextConfig(PretrainedConfig): ...@@ -117,6 +117,8 @@ class CLIPSegTextConfig(PretrainedConfig):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the text config dict if we are loading from CLIPSegConfig # get the text config dict if we are loading from CLIPSegConfig
...@@ -218,6 +220,8 @@ class CLIPSegVisionConfig(PretrainedConfig): ...@@ -218,6 +220,8 @@ class CLIPSegVisionConfig(PretrainedConfig):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the vision config dict if we are loading from CLIPSegConfig # get the vision config dict if we are loading from CLIPSegConfig
......
...@@ -131,6 +131,8 @@ class FlavaImageConfig(PretrainedConfig): ...@@ -131,6 +131,8 @@ class FlavaImageConfig(PretrainedConfig):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the image config dict if we are loading from FlavaConfig # get the image config dict if we are loading from FlavaConfig
...@@ -258,6 +260,8 @@ class FlavaTextConfig(PretrainedConfig): ...@@ -258,6 +260,8 @@ class FlavaTextConfig(PretrainedConfig):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the text config dict if we are loading from FlavaConfig # get the text config dict if we are loading from FlavaConfig
...@@ -359,6 +363,8 @@ class FlavaMultimodalConfig(PretrainedConfig): ...@@ -359,6 +363,8 @@ class FlavaMultimodalConfig(PretrainedConfig):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the multimodal config dict if we are loading from FlavaConfig # get the multimodal config dict if we are loading from FlavaConfig
...@@ -442,6 +448,8 @@ class FlavaImageCodebookConfig(PretrainedConfig): ...@@ -442,6 +448,8 @@ class FlavaImageCodebookConfig(PretrainedConfig):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the image codebook config dict if we are loading from FlavaConfig # get the image codebook config dict if we are loading from FlavaConfig
......
...@@ -109,6 +109,8 @@ class GitVisionConfig(PretrainedConfig): ...@@ -109,6 +109,8 @@ class GitVisionConfig(PretrainedConfig):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the vision config dict if we are loading from GITConfig # get the vision config dict if we are loading from GITConfig
......
...@@ -127,6 +127,8 @@ class GroupViTTextConfig(PretrainedConfig): ...@@ -127,6 +127,8 @@ class GroupViTTextConfig(PretrainedConfig):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the text config dict if we are loading from GroupViTConfig # get the text config dict if we are loading from GroupViTConfig
...@@ -250,6 +252,8 @@ class GroupViTVisionConfig(PretrainedConfig): ...@@ -250,6 +252,8 @@ class GroupViTVisionConfig(PretrainedConfig):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the vision config dict if we are loading from GroupViTConfig # get the vision config dict if we are loading from GroupViTConfig
......
...@@ -127,6 +127,8 @@ class OwlViTTextConfig(PretrainedConfig): ...@@ -127,6 +127,8 @@ class OwlViTTextConfig(PretrainedConfig):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the text config dict if we are loading from OwlViTConfig # get the text config dict if we are loading from OwlViTConfig
...@@ -230,6 +232,8 @@ class OwlViTVisionConfig(PretrainedConfig): ...@@ -230,6 +232,8 @@ class OwlViTVisionConfig(PretrainedConfig):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the vision config dict if we are loading from OwlViTConfig # get the vision config dict if we are loading from OwlViTConfig
...@@ -301,6 +305,8 @@ class OwlViTConfig(PretrainedConfig): ...@@ -301,6 +305,8 @@ class OwlViTConfig(PretrainedConfig):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
......
...@@ -153,6 +153,8 @@ class Pix2StructTextConfig(PretrainedConfig): ...@@ -153,6 +153,8 @@ class Pix2StructTextConfig(PretrainedConfig):
def from_pretrained( def from_pretrained(
cls, pretrainehidden_size_name_or_path: Union[str, os.PathLike], **kwargs cls, pretrainehidden_size_name_or_path: Union[str, os.PathLike], **kwargs
) -> "PretrainedConfig": ) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrainehidden_size_name_or_path, **kwargs) config_dict, kwargs = cls.get_config_dict(pretrainehidden_size_name_or_path, **kwargs)
# get the text config dict if we are loading from Pix2StructConfig # get the text config dict if we are loading from Pix2StructConfig
...@@ -270,6 +272,8 @@ class Pix2StructVisionConfig(PretrainedConfig): ...@@ -270,6 +272,8 @@ class Pix2StructVisionConfig(PretrainedConfig):
def from_pretrained( def from_pretrained(
cls, pretrainehidden_size_name_or_path: Union[str, os.PathLike], **kwargs cls, pretrainehidden_size_name_or_path: Union[str, os.PathLike], **kwargs
) -> "PretrainedConfig": ) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrainehidden_size_name_or_path, **kwargs) config_dict, kwargs = cls.get_config_dict(pretrainehidden_size_name_or_path, **kwargs)
# get the vision config dict if we are loading from Pix2StructConfig # get the vision config dict if we are loading from Pix2StructConfig
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment