Unverified Commit f9a2a9e3 authored by Giovanni Compagnoni's avatar Giovanni Compagnoni Committed by GitHub
Browse files

Extend typing to path-like objects in `PretrainedConfig` and `PreTrainedModel` (#8770)

* update configuration_utils.py typing to allow pathlike objects when sensible

* update modeling_utils.py typing to allow pathlike objects when sensible

* black

* update tokenization_utils_base.py typing to allow pathlike objects when sensible

* update tokenization_utils_fast.py typing to allow pathlike objects when sensible

* update configuration_auto.py typing to allow pathlike objects when sensible

* update configuration_auto.py docstring to allow pathlike objects when sensible

* update tokenization_auto.py docstring to allow pathlike objects when sensible

* black
parent a7d46a06
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
import copy import copy
import json import json
import os import os
from typing import Any, Dict, Tuple from typing import Any, Dict, Tuple, Union
from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
from .utils import logging from .utils import logging
...@@ -262,13 +262,13 @@ class PretrainedConfig(object): ...@@ -262,13 +262,13 @@ class PretrainedConfig(object):
self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)} self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)}
self.label2id = dict(zip(self.id2label.values(), self.id2label.keys())) self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
def save_pretrained(self, save_directory: str): def save_pretrained(self, save_directory: Union[str, os.PathLike]):
""" """
Save a configuration object to the directory ``save_directory``, so that it can be re-loaded using the Save a configuration object to the directory ``save_directory``, so that it can be re-loaded using the
:func:`~transformers.PretrainedConfig.from_pretrained` class method. :func:`~transformers.PretrainedConfig.from_pretrained` class method.
Args: Args:
save_directory (:obj:`str`): save_directory (:obj:`str` or :obj:`os.PathLike`):
Directory where the configuration JSON file will be saved (will be created if it does not exist). Directory where the configuration JSON file will be saved (will be created if it does not exist).
""" """
if os.path.isfile(save_directory): if os.path.isfile(save_directory):
...@@ -281,13 +281,13 @@ class PretrainedConfig(object): ...@@ -281,13 +281,13 @@ class PretrainedConfig(object):
logger.info("Configuration saved in {}".format(output_config_file)) logger.info("Configuration saved in {}".format(output_config_file))
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> "PretrainedConfig": def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
r""" r"""
Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pretrained model Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pretrained model
configuration. configuration.
Args: Args:
pretrained_model_name_or_path (:obj:`str`): pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
This can be either: This can be either:
- a string, the `model id` of a pretrained model configuration hosted inside a model repo on - a string, the `model id` of a pretrained model configuration hosted inside a model repo on
...@@ -297,7 +297,7 @@ class PretrainedConfig(object): ...@@ -297,7 +297,7 @@ class PretrainedConfig(object):
:func:`~transformers.PretrainedConfig.save_pretrained` method, e.g., ``./my_model_directory/``. :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g., ``./my_model_directory/``.
- a path or url to a saved configuration JSON `file`, e.g., - a path or url to a saved configuration JSON `file`, e.g.,
``./my_model_directory/configuration.json``. ``./my_model_directory/configuration.json``.
cache_dir (:obj:`str`, `optional`): cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
Path to a directory in which a downloaded pretrained model configuration should be cached if the Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used. standard cache should not be used.
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
...@@ -346,13 +346,15 @@ class PretrainedConfig(object): ...@@ -346,13 +346,15 @@ class PretrainedConfig(object):
return cls.from_dict(config_dict, **kwargs) return cls.from_dict(config_dict, **kwargs)
@classmethod @classmethod
def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]: def get_config_dict(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
""" """
From a ``pretrained_model_name_or_path``, resolve to a dictionary of parameters, to be used for instantiating a From a ``pretrained_model_name_or_path``, resolve to a dictionary of parameters, to be used for instantiating a
:class:`~transformers.PretrainedConfig` using ``from_dict``. :class:`~transformers.PretrainedConfig` using ``from_dict``.
Parameters: Parameters:
pretrained_model_name_or_path (:obj:`str`): pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
Returns: Returns:
...@@ -366,6 +368,7 @@ class PretrainedConfig(object): ...@@ -366,6 +368,7 @@ class PretrainedConfig(object):
local_files_only = kwargs.pop("local_files_only", False) local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path): if os.path.isdir(pretrained_model_name_or_path):
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
...@@ -451,12 +454,12 @@ class PretrainedConfig(object): ...@@ -451,12 +454,12 @@ class PretrainedConfig(object):
return config return config
@classmethod @classmethod
def from_json_file(cls, json_file: str) -> "PretrainedConfig": def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "PretrainedConfig":
""" """
Instantiates a :class:`~transformers.PretrainedConfig` from the path to a JSON file of parameters. Instantiates a :class:`~transformers.PretrainedConfig` from the path to a JSON file of parameters.
Args: Args:
json_file (:obj:`str`): json_file (:obj:`str` or :obj:`os.PathLike`):
Path to the JSON file containing the parameters. Path to the JSON file containing the parameters.
Returns: Returns:
...@@ -467,7 +470,7 @@ class PretrainedConfig(object): ...@@ -467,7 +470,7 @@ class PretrainedConfig(object):
return cls(**config_dict) return cls(**config_dict)
@classmethod @classmethod
def _dict_from_json_file(cls, json_file: str): def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
with open(json_file, "r", encoding="utf-8") as reader: with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read() text = reader.read()
return json.loads(text) return json.loads(text)
...@@ -537,12 +540,12 @@ class PretrainedConfig(object): ...@@ -537,12 +540,12 @@ class PretrainedConfig(object):
config_dict = self.to_dict() config_dict = self.to_dict()
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path: str, use_diff: bool = True): def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
""" """
Save this instance to a JSON file. Save this instance to a JSON file.
Args: Args:
json_file_path (:obj:`str`): json_file_path (:obj:`str` or :obj:`os.PathLike`):
Path to the JSON file in which this configuration instance's parameters will be saved. Path to the JSON file in which this configuration instance's parameters will be saved.
use_diff (:obj:`bool`, `optional`, defaults to :obj:`True`): use_diff (:obj:`bool`, `optional`, defaults to :obj:`True`):
If set to ``True``, only the difference between the config instance and the default If set to ``True``, only the difference between the config instance and the default
......
...@@ -697,13 +697,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): ...@@ -697,13 +697,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
self.base_model._prune_heads(heads_to_prune) self.base_model._prune_heads(heads_to_prune)
def save_pretrained(self, save_directory): def save_pretrained(self, save_directory: Union[str, os.PathLike]):
""" """
Save a model and its configuration file to a directory, so that it can be re-loaded using the Save a model and its configuration file to a directory, so that it can be re-loaded using the
`:func:`~transformers.PreTrainedModel.from_pretrained`` class method. `:func:`~transformers.PreTrainedModel.from_pretrained`` class method.
Arguments: Arguments:
save_directory (:obj:`str`): save_directory (:obj:`str` or :obj:`os.PathLike`):
Directory to which to save. Will be created if it doesn't exist. Directory to which to save. Will be created if it doesn't exist.
""" """
if os.path.isfile(save_directory): if os.path.isfile(save_directory):
...@@ -741,7 +741,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): ...@@ -741,7 +741,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
logger.info("Model weights saved in {}".format(output_model_file)) logger.info("Model weights saved in {}".format(output_model_file))
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
r""" r"""
Instantiate a pretrained pytorch model from a pre-trained model configuration. Instantiate a pretrained pytorch model from a pre-trained model configuration.
...@@ -756,7 +756,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): ...@@ -756,7 +756,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
weights are discarded. weights are discarded.
Parameters: Parameters:
pretrained_model_name_or_path (:obj:`str`, `optional`): pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`, `optional`):
Can be either: Can be either:
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
...@@ -772,11 +772,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): ...@@ -772,11 +772,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
arguments ``config`` and ``state_dict``). arguments ``config`` and ``state_dict``).
model_args (sequence of positional arguments, `optional`): model_args (sequence of positional arguments, `optional`):
All remaning positional arguments will be passed to the underlying model's ``__init__`` method. All remaning positional arguments will be passed to the underlying model's ``__init__`` method.
config (:obj:`Union[PretrainedConfig, str]`, `optional`): config (:obj:`Union[PretrainedConfig, str, os.PathLike]`, `optional`):
Can be either: Can be either:
- an instance of a class derived from :class:`~transformers.PretrainedConfig`, - an instance of a class derived from :class:`~transformers.PretrainedConfig`,
- a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`. - a string or path valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.
Configuration for the model to use instead of an automatically loaded configuation. Configuration can Configuration for the model to use instead of an automatically loaded configuation. Configuration can
be automatically loaded when: be automatically loaded when:
...@@ -794,7 +794,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): ...@@ -794,7 +794,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
weights. In this case though, you should check if using weights. In this case though, you should check if using
:func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.save_pretrained` and
:func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option. :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
cache_dir (:obj:`str`, `optional`): cache_dir (:obj:`Union[str, os.PathLike]`, `optional`):
Path to a directory in which a downloaded pretrained model configuration should be cached if the Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used. standard cache should not be used.
from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`): from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`):
...@@ -881,6 +881,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): ...@@ -881,6 +881,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
# Load model # Load model
if pretrained_model_name_or_path is not None: if pretrained_model_name_or_path is not None:
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path): if os.path.isdir(pretrained_model_name_or_path):
if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")): if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
# Load from a TF 1.0 checkpoint in priority if from_tf # Load from a TF 1.0 checkpoint in priority if from_tf
......
...@@ -274,7 +274,7 @@ class AutoConfig: ...@@ -274,7 +274,7 @@ class AutoConfig:
List options List options
Args: Args:
pretrained_model_name_or_path (:obj:`str`): pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
Can be either: Can be either:
- A string, the `model id` of a pretrained model configuration hosted inside a model repo on - A string, the `model id` of a pretrained model configuration hosted inside a model repo on
...@@ -285,7 +285,7 @@ class AutoConfig: ...@@ -285,7 +285,7 @@ class AutoConfig:
:meth:`~transformers.PreTrainedModel.save_pretrained` method, e.g., ``./my_model_directory/``. :meth:`~transformers.PreTrainedModel.save_pretrained` method, e.g., ``./my_model_directory/``.
- A path or url to a saved configuration JSON `file`, e.g., - A path or url to a saved configuration JSON `file`, e.g.,
``./my_model_directory/configuration.json``. ``./my_model_directory/configuration.json``.
cache_dir (:obj:`str`, `optional`): cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
Path to a directory in which a downloaded pretrained model configuration should be cached if the Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used. standard cache should not be used.
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
...@@ -346,7 +346,7 @@ class AutoConfig: ...@@ -346,7 +346,7 @@ class AutoConfig:
else: else:
# Fallback: use pattern matching on the string. # Fallback: use pattern matching on the string.
for pattern, config_class in CONFIG_MAPPING.items(): for pattern, config_class in CONFIG_MAPPING.items():
if pattern in pretrained_model_name_or_path: if pattern in str(pretrained_model_name_or_path):
return config_class.from_dict(config_dict, **kwargs) return config_class.from_dict(config_dict, **kwargs)
raise ValueError( raise ValueError(
......
...@@ -502,7 +502,7 @@ AUTO_MODEL_PRETRAINED_DOCSTRING = r""" ...@@ -502,7 +502,7 @@ AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
deactivated). To train the model, you should first set it back in training mode with ``model.train()`` deactivated). To train the model, you should first set it back in training mode with ``model.train()``
Args: Args:
pretrained_model_name_or_path: pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
Can be either: Can be either:
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
...@@ -533,7 +533,7 @@ AUTO_MODEL_PRETRAINED_DOCSTRING = r""" ...@@ -533,7 +533,7 @@ AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
weights. In this case though, you should check if using weights. In this case though, you should check if using
:func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.save_pretrained` and
:func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option. :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
cache_dir (:obj:`str`, `optional`): cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
Path to a directory in which a downloaded pretrained model configuration should be cached if the Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used. standard cache should not be used.
from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`): from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`):
......
...@@ -267,7 +267,7 @@ class AutoTokenizer: ...@@ -267,7 +267,7 @@ class AutoTokenizer:
List options List options
Params: Params:
pretrained_model_name_or_path (:obj:`str`): pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
Can be either: Can be either:
- A string, the `model id` of a predefined tokenizer hosted inside a model repo on huggingface.co. - A string, the `model id` of a predefined tokenizer hosted inside a model repo on huggingface.co.
...@@ -283,7 +283,7 @@ class AutoTokenizer: ...@@ -283,7 +283,7 @@ class AutoTokenizer:
Will be passed along to the Tokenizer ``__init__()`` method. Will be passed along to the Tokenizer ``__init__()`` method.
config (:class:`~transformers.PreTrainedConfig`, `optional`) config (:class:`~transformers.PreTrainedConfig`, `optional`)
The configuration object used to dertermine the tokenizer class to instantiate. The configuration object used to dertermine the tokenizer class to instantiate.
cache_dir (:obj:`str`, `optional`): cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
Path to a directory in which a downloaded pretrained model configuration should be cached if the Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used. standard cache should not be used.
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
......
...@@ -1608,13 +1608,13 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): ...@@ -1608,13 +1608,13 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
raise NotImplementedError() raise NotImplementedError()
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], *init_inputs, **kwargs):
r""" r"""
Instantiate a :class:`~transformers.tokenization_utils_base.PreTrainedTokenizerBase` (or a derived class) from Instantiate a :class:`~transformers.tokenization_utils_base.PreTrainedTokenizerBase` (or a derived class) from
a predefined tokenizer. a predefined tokenizer.
Args: Args:
pretrained_model_name_or_path (:obj:`str`): pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
Can be either: Can be either:
- A string, the `model id` of a predefined tokenizer hosted inside a model repo on huggingface.co. - A string, the `model id` of a predefined tokenizer hosted inside a model repo on huggingface.co.
...@@ -1626,7 +1626,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): ...@@ -1626,7 +1626,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
- (**Deprecated**, not applicable to all derived classes) A path or url to a single saved vocabulary - (**Deprecated**, not applicable to all derived classes) A path or url to a single saved vocabulary
file (if and only if the tokenizer only requires a single vocabulary file like Bert or XLNet), e.g., file (if and only if the tokenizer only requires a single vocabulary file like Bert or XLNet), e.g.,
``./my_model_directory/vocab.txt``. ``./my_model_directory/vocab.txt``.
cache_dir (:obj:`str`, `optional`): cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the
standard cache should not be used. standard cache should not be used.
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
...@@ -1683,6 +1683,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): ...@@ -1683,6 +1683,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
subfolder = kwargs.pop("subfolder", None) subfolder = kwargs.pop("subfolder", None)
s3_models = list(cls.max_model_input_sizes.keys()) s3_models = list(cls.max_model_input_sizes.keys())
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
vocab_files = {} vocab_files = {}
init_configuration = {} init_configuration = {}
if pretrained_model_name_or_path in s3_models: if pretrained_model_name_or_path in s3_models:
...@@ -1904,7 +1905,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): ...@@ -1904,7 +1905,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
return tokenizer return tokenizer
def save_pretrained( def save_pretrained(
self, save_directory: str, legacy_format: bool = True, filename_prefix: Optional[str] = None self,
save_directory: Union[str, os.PathLike],
legacy_format: bool = True,
filename_prefix: Optional[str] = None,
) -> Tuple[str]: ) -> Tuple[str]:
""" """
Save the full tokenizer state. Save the full tokenizer state.
...@@ -1924,7 +1928,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): ...@@ -1924,7 +1928,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
modifying :obj:`tokenizer.do_lower_case` after creation). modifying :obj:`tokenizer.do_lower_case` after creation).
Args: Args:
save_directory (:obj:`str`): The path to a directory where the tokenizer will be saved. save_directory (:obj:`str` or :obj:`os.PathLike`): The path to a directory where the tokenizer will be saved.
legacy_format (:obj:`bool`, `optional`, defaults to :obj:`True`): legacy_format (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to save the tokenizer in legacy format (default), i.e. with tokenizer specific vocabulary and a Whether to save the tokenizer in legacy format (default), i.e. with tokenizer specific vocabulary and a
separate added_tokens files or in the unified JSON file format for the `tokenizers` library. It's only separate added_tokens files or in the unified JSON file format for the `tokenizers` library. It's only
...@@ -1988,7 +1992,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): ...@@ -1988,7 +1992,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
def _save_pretrained( def _save_pretrained(
self, self,
save_directory: str, save_directory: Union[str, os.PathLike],
file_names: Tuple[str], file_names: Tuple[str],
legacy_format: bool = True, legacy_format: bool = True,
filename_prefix: Optional[str] = None, filename_prefix: Optional[str] = None,
......
...@@ -498,7 +498,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase): ...@@ -498,7 +498,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
def _save_pretrained( def _save_pretrained(
self, self,
save_directory: str, save_directory: Union[str, os.PathLike],
file_names: Tuple[str], file_names: Tuple[str],
legacy_format: bool = True, legacy_format: bool = True,
filename_prefix: Optional[str] = None, filename_prefix: Optional[str] = None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment