Unverified commit f9a2a9e3, authored by Giovanni Compagnoni, committed by GitHub

Extend typing to path-like objects in `PretrainedConfig` and `PreTrainedModel` (#8770)

* update configuration_utils.py typing to allow pathlike objects when sensible

* update modeling_utils.py typing to allow pathlike objects when sensible

* black

* update tokenization_utils_base.py typing to allow pathlike objects when sensible

* update tokenization_utils_fast.py typing to allow pathlike objects when sensible

* update configuration_auto.py typing to allow pathlike objects when sensible

* update configuration_auto.py docstring to allow pathlike objects when sensible

* update tokenization_auto.py docstring to allow pathlike objects when sensible

* black
parent a7d46a06
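In practice, this change lets `pathlib.Path` objects flow through the save/load entry points that previously annotated (and sometimes assumed) plain strings. A minimal sketch of the new calling convention, assuming only what the diff below shows (the directory name and the `hidden_size` attribute are illustrative):

```python
from pathlib import Path

from transformers import PretrainedConfig

save_dir = Path("./my_model_directory")    # illustrative directory

config = PretrainedConfig(hidden_size=64)  # extra kwargs become attributes
config.save_pretrained(save_dir)           # a Path, not a str
reloaded = PretrainedConfig.from_pretrained(save_dir)
assert reloaded.hidden_size == 64
```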
configuration_utils.py
@@ -19,7 +19,7 @@
 import copy
 import json
 import os
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, Tuple, Union

 from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
 from .utils import logging
@@ -262,13 +262,13 @@ class PretrainedConfig(object):
         self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)}
         self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))

-    def save_pretrained(self, save_directory: str):
+    def save_pretrained(self, save_directory: Union[str, os.PathLike]):
         """
         Save a configuration object to the directory ``save_directory``, so that it can be re-loaded using the
         :func:`~transformers.PretrainedConfig.from_pretrained` class method.

         Args:
-            save_directory (:obj:`str`):
+            save_directory (:obj:`str` or :obj:`os.PathLike`):
                 Directory where the configuration JSON file will be saved (will be created if it does not exist).
         """
         if os.path.isfile(save_directory):
@@ -281,13 +281,13 @@ class PretrainedConfig(object):
         logger.info("Configuration saved in {}".format(output_config_file))

     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> "PretrainedConfig":
+    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
         r"""
         Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pretrained model
         configuration.

         Args:
-            pretrained_model_name_or_path (:obj:`str`):
+            pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
                 This can be either:

                 - a string, the `model id` of a pretrained model configuration hosted inside a model repo on
@@ -297,7 +297,7 @@ class PretrainedConfig(object):
                   :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g., ``./my_model_directory/``.
                 - a path or url to a saved configuration JSON `file`, e.g.,
                   ``./my_model_directory/configuration.json``.
-            cache_dir (:obj:`str`, `optional`):
+            cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
                 Path to a directory in which a downloaded pretrained model configuration should be cached if the
                 standard cache should not be used.
             force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
@@ -346,13 +346,15 @@ class PretrainedConfig(object):
         return cls.from_dict(config_dict, **kwargs)

     @classmethod
-    def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+    def get_config_dict(
+        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         """
         From a ``pretrained_model_name_or_path``, resolve to a dictionary of parameters, to be used for instantiating a
         :class:`~transformers.PretrainedConfig` using ``from_dict``.

         Parameters:
-            pretrained_model_name_or_path (:obj:`str`):
+            pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
                 The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.

         Returns:
@@ -366,6 +368,7 @@ class PretrainedConfig(object):
         local_files_only = kwargs.pop("local_files_only", False)
         revision = kwargs.pop("revision", None)

+        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
         if os.path.isdir(pretrained_model_name_or_path):
             config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
         elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
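The `str()` cast added at the top of `get_config_dict` is the main runtime change here: everything below it (URL detection, `os.path.join`, log formatting) can keep assuming a plain string. A stripped-down sketch of the pattern; the helper name is hypothetical and `config.json` stands in for `CONFIG_NAME`:

```python
import os

def resolve_config_file(name_or_path, config_name="config.json"):
    # Normalize os.PathLike to str once, so every branch below can
    # rely on string semantics (joins, URL checks, formatting).
    name_or_path = str(name_or_path)
    if os.path.isdir(name_or_path):
        return os.path.join(name_or_path, config_name)
    # Plain files, URLs and hub model ids pass through unchanged.
    return name_or_path
```

With the cast in place, `Path("./my_model_directory")` and `"./my_model_directory"` resolve to the same file.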
@@ -451,12 +454,12 @@ class PretrainedConfig(object):
         return config

     @classmethod
-    def from_json_file(cls, json_file: str) -> "PretrainedConfig":
+    def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "PretrainedConfig":
         """
         Instantiates a :class:`~transformers.PretrainedConfig` from the path to a JSON file of parameters.

         Args:
-            json_file (:obj:`str`):
+            json_file (:obj:`str` or :obj:`os.PathLike`):
                 Path to the JSON file containing the parameters.

         Returns:
@@ -467,7 +470,7 @@ class PretrainedConfig(object):
         return cls(**config_dict)

     @classmethod
-    def _dict_from_json_file(cls, json_file: str):
+    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
         with open(json_file, "r", encoding="utf-8") as reader:
             text = reader.read()
         return json.loads(text)
@@ -537,12 +540,12 @@ class PretrainedConfig(object):
         config_dict = self.to_dict()
         return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

-    def to_json_file(self, json_file_path: str, use_diff: bool = True):
+    def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
         """
         Save this instance to a JSON file.

         Args:
-            json_file_path (:obj:`str`):
+            json_file_path (:obj:`str` or :obj:`os.PathLike`):
                 Path to the JSON file in which this configuration instance's parameters will be saved.
             use_diff (:obj:`bool`, `optional`, defaults to :obj:`True`):
                 If set to ``True``, only the difference between the config instance and the default
......
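`to_json_file` and `from_json_file` now accept path-like objects too; since `open()` has taken `os.PathLike` since Python 3.6, only the annotations needed to change. A small round trip (the file name and the `my_param` attribute are illustrative):

```python
from pathlib import Path

from transformers import PretrainedConfig

json_path = Path("config.json")  # illustrative file name

cfg = PretrainedConfig(my_param=1)
cfg.to_json_file(json_path)
restored = PretrainedConfig.from_json_file(json_path)
assert restored.my_param == 1
```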
modeling_utils.py
@@ -697,13 +697,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
         self.base_model._prune_heads(heads_to_prune)

-    def save_pretrained(self, save_directory):
+    def save_pretrained(self, save_directory: Union[str, os.PathLike]):
         """
         Save a model and its configuration file to a directory, so that it can be re-loaded using the
         :func:`~transformers.PreTrainedModel.from_pretrained` class method.

         Arguments:
-            save_directory (:obj:`str`):
+            save_directory (:obj:`str` or :obj:`os.PathLike`):
                 Directory to which to save. Will be created if it doesn't exist.
         """
         if os.path.isfile(save_directory):
@@ -741,7 +741,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
         logger.info("Model weights saved in {}".format(output_model_file))

     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
         r"""
         Instantiate a pretrained pytorch model from a pre-trained model configuration.
@@ -756,7 +756,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
         weights are discarded.

         Parameters:
-            pretrained_model_name_or_path (:obj:`str`, `optional`):
+            pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`, `optional`):
                 Can be either:

                 - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
@@ -772,11 +772,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
                   arguments ``config`` and ``state_dict``).
             model_args (sequence of positional arguments, `optional`):
                 All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
-            config (:obj:`Union[PretrainedConfig, str]`, `optional`):
+            config (:obj:`Union[PretrainedConfig, str, os.PathLike]`, `optional`):
                 Can be either:

                 - an instance of a class derived from :class:`~transformers.PretrainedConfig`,
-                - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.
+                - a string or path valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.

                 Configuration for the model to use instead of an automatically loaded configuration. Configuration can
                 be automatically loaded when:
@@ -794,7 +794,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
                 weights. In this case though, you should check if using
                 :func:`~transformers.PreTrainedModel.save_pretrained` and
                 :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
-            cache_dir (:obj:`str`, `optional`):
+            cache_dir (:obj:`Union[str, os.PathLike]`, `optional`):
                 Path to a directory in which a downloaded pretrained model configuration should be cached if the
                 standard cache should not be used.
             from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`):
@@ -881,6 +881,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
         # Load model
         if pretrained_model_name_or_path is not None:
+            pretrained_model_name_or_path = str(pretrained_model_name_or_path)
             if os.path.isdir(pretrained_model_name_or_path):
                 if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
                     # Load from a TF 1.0 checkpoint in priority if from_tf
......
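Models follow the same pattern as configs: the name is coerced to `str` before the directory and checkpoint checks, so a `Path` works for saving and loading alike. A hedged round-trip sketch using a tiny randomly initialized BERT (the sizes and directory name are illustrative, chosen only to keep the example fast):

```python
from pathlib import Path

from transformers import BertConfig, BertModel

config = BertConfig(
    hidden_size=32, num_hidden_layers=1, num_attention_heads=2, intermediate_size=64
)
model = BertModel(config)

out_dir = Path("./tiny-bert")
model.save_pretrained(out_dir)             # os.PathLike accepted
reloaded = BertModel.from_pretrained(out_dir)
```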
configuration_auto.py
@@ -274,7 +274,7 @@ class AutoConfig:
         List options

         Args:
-            pretrained_model_name_or_path (:obj:`str`):
+            pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
                 Can be either:

                 - A string, the `model id` of a pretrained model configuration hosted inside a model repo on
@@ -285,7 +285,7 @@ class AutoConfig:
                   :meth:`~transformers.PreTrainedModel.save_pretrained` method, e.g., ``./my_model_directory/``.
                 - A path or url to a saved configuration JSON `file`, e.g.,
                   ``./my_model_directory/configuration.json``.
-            cache_dir (:obj:`str`, `optional`):
+            cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
                 Path to a directory in which a downloaded pretrained model configuration should be cached if the
                 standard cache should not be used.
             force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
@@ -346,7 +346,7 @@ class AutoConfig:
         else:
             # Fallback: use pattern matching on the string.
             for pattern, config_class in CONFIG_MAPPING.items():
-                if pattern in pretrained_model_name_or_path:
+                if pattern in str(pretrained_model_name_or_path):
                     return config_class.from_dict(config_dict, **kwargs)

         raise ValueError(
......
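The cast in the fallback branch is not cosmetic: the `in` operator raises on a `Path`, because `PurePath` defines neither `__contains__` nor `__iter__`. A quick illustration:

```python
from pathlib import Path

name = Path("./my-bert-model")  # illustrative local checkpoint path

try:
    "bert" in name  # substring test against a Path object
except TypeError:
    # e.g. "argument of type 'PosixPath' is not iterable"
    pass

assert "bert" in str(name)  # the cast restores substring semantics
```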
modeling_auto.py
@@ -502,7 +502,7 @@ AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
         deactivated). To train the model, you should first set it back in training mode with ``model.train()``

         Args:
-            pretrained_model_name_or_path:
+            pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
                 Can be either:

                 - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
@@ -533,7 +533,7 @@ AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
                 weights. In this case though, you should check if using
                 :func:`~transformers.PreTrainedModel.save_pretrained` and
                 :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
-            cache_dir (:obj:`str`, `optional`):
+            cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
                 Path to a directory in which a downloaded pretrained model configuration should be cached if the
                 standard cache should not be used.
             from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`):
......
tokenization_auto.py
@@ -267,7 +267,7 @@ class AutoTokenizer:
         List options

         Params:
-            pretrained_model_name_or_path (:obj:`str`):
+            pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
                 Can be either:

                 - A string, the `model id` of a predefined tokenizer hosted inside a model repo on huggingface.co.
@@ -283,7 +283,7 @@ class AutoTokenizer:
                 Will be passed along to the Tokenizer ``__init__()`` method.
             config (:class:`~transformers.PretrainedConfig`, `optional`):
                 The configuration object used to determine the tokenizer class to instantiate.
-            cache_dir (:obj:`str`, `optional`):
+            cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
                 Path to a directory in which a downloaded pretrained model configuration should be cached if the
                 standard cache should not be used.
             force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
......
tokenization_utils_base.py
@@ -1608,13 +1608,13 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
         raise NotImplementedError()

     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], *init_inputs, **kwargs):
         r"""
         Instantiate a :class:`~transformers.tokenization_utils_base.PreTrainedTokenizerBase` (or a derived class) from
         a predefined tokenizer.

         Args:
-            pretrained_model_name_or_path (:obj:`str`):
+            pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
                 Can be either:

                 - A string, the `model id` of a predefined tokenizer hosted inside a model repo on huggingface.co.
@@ -1626,7 +1626,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
                 - (**Deprecated**, not applicable to all derived classes) A path or url to a single saved vocabulary
                   file (if and only if the tokenizer only requires a single vocabulary file like Bert or XLNet), e.g.,
                   ``./my_model_directory/vocab.txt``.
-            cache_dir (:obj:`str`, `optional`):
+            cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
                 Path to a directory in which the downloaded predefined tokenizer vocabulary files should be cached if
                 the standard cache should not be used.
             force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
@@ -1683,6 +1683,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
         subfolder = kwargs.pop("subfolder", None)

         s3_models = list(cls.max_model_input_sizes.keys())
+        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
         vocab_files = {}
         init_configuration = {}
         if pretrained_model_name_or_path in s3_models:
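The added cast also matters for the shortcut lookup just below it: a `Path` never compares equal to a `str`, so without it `pretrained_model_name_or_path in s3_models` would silently miss instead of raising. A sketch with an illustrative shortcut list:

```python
from pathlib import Path

s3_models = ["bert-base-uncased"]  # illustrative shortcut list

name = Path("bert-base-uncased")
assert name not in s3_models       # Path != str, so the lookup misses
assert str(name) in s3_models      # after the cast, the shortcut matches
```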
@@ -1904,7 +1905,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
         return tokenizer

     def save_pretrained(
-        self, save_directory: str, legacy_format: bool = True, filename_prefix: Optional[str] = None
+        self,
+        save_directory: Union[str, os.PathLike],
+        legacy_format: bool = True,
+        filename_prefix: Optional[str] = None,
     ) -> Tuple[str]:
         """
         Save the full tokenizer state.
@@ -1924,7 +1928,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
         modifying :obj:`tokenizer.do_lower_case` after creation).

         Args:
-            save_directory (:obj:`str`): The path to a directory where the tokenizer will be saved.
+            save_directory (:obj:`str` or :obj:`os.PathLike`): The path to a directory where the tokenizer will be saved.
             legacy_format (:obj:`bool`, `optional`, defaults to :obj:`True`):
                 Whether to save the tokenizer in legacy format (default), i.e. with tokenizer specific vocabulary and a
                 separate added_tokens file, or in the unified JSON file format for the `tokenizers` library. It's only
@@ -1988,7 +1992,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
     def _save_pretrained(
         self,
-        save_directory: str,
+        save_directory: Union[str, os.PathLike],
         file_names: Tuple[str],
         legacy_format: bool = True,
         filename_prefix: Optional[str] = None,
......
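With these signatures updated, the tokenizer API mirrors the config and model APIs. A usage sketch (this downloads the real `bert-base-uncased` tokenizer; the output directory is illustrative):

```python
from pathlib import Path

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

out_dir = Path("./bert-tokenizer")
saved_files = tokenizer.save_pretrained(out_dir)  # a Path works here too
print(saved_files)  # tuple of paths to the files that were written
```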
tokenization_utils_fast.py
@@ -498,7 +498,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
     def _save_pretrained(
         self,
-        save_directory: str,
+        save_directory: Union[str, os.PathLike],
         file_names: Tuple[str],
         legacy_format: bool = True,
         filename_prefix: Optional[str] = None,
......