Commit 009fcb0e authored by Lysandre, committed by Lysandre Debut
Browse files

Configuration utils

parent 11b13e94
...@@ -40,12 +40,17 @@ class PretrainedConfig(object): ...@@ -40,12 +40,17 @@ class PretrainedConfig(object):
- ``pretrained_config_archive_map``: a python ``dict`` with `shortcut names` (string) as keys and `url` (string) of associated pretrained model configurations as values. - ``pretrained_config_archive_map``: a python ``dict`` with `shortcut names` (string) as keys and `url` (string) of associated pretrained model configurations as values.
- ``model_type``: a string that identifies the model type, that we serialize into the JSON file, and that we use to recreate the correct object in :class:`~transformers.AutoConfig`. - ``model_type``: a string that identifies the model type, that we serialize into the JSON file, and that we use to recreate the correct object in :class:`~transformers.AutoConfig`.
Parameters: Args:
``finetuning_task``: string, default `None`. Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint. finetuning_task (:obj:`string` or :obj:`None`, `optional`, defaults to :obj:`None`):
``num_labels``: integer, default `2`. Number of classes to use when the model is a classification model (sequences/tokens) Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint.
``output_attentions``: boolean, default `False`. Should the model returns attentions weights. num_labels (:obj:`int`, `optional`, defaults to `2`):
``output_hidden_states``: string, default `False`. Should the model returns all hidden-states. Number of classes to use when the model is a classification model (sequences/tokens)
``torchscript``: string, default `False`. Is the model used with Torchscript. output_attentions (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether the model should return attention weights.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether the model should return all hidden-states.
torchscript (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether the model is used with TorchScript (for PyTorch models).
""" """
pretrained_config_archive_map = {} # type: Dict[str, str] pretrained_config_archive_map = {} # type: Dict[str, str]
model_type = "" # type: str model_type = "" # type: str
...@@ -93,8 +98,13 @@ class PretrainedConfig(object): ...@@ -93,8 +98,13 @@ class PretrainedConfig(object):
raise err raise err
def save_pretrained(self, save_directory): def save_pretrained(self, save_directory):
""" Save a configuration object to the directory `save_directory`, so that it """
can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method. Save a configuration object to the directory `save_directory`, so that it
can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.
Args:
save_directory (:obj:`string`):
Directory where the configuration JSON file will be saved.
""" """
assert os.path.isdir( assert os.path.isdir(
save_directory save_directory
...@@ -107,40 +117,45 @@ class PretrainedConfig(object): ...@@ -107,40 +117,45 @@ class PretrainedConfig(object):
logger.info("Configuration saved in {}".format(output_config_file)) logger.info("Configuration saved in {}".format(output_config_file))
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> 'PretrainedConfig':
r""" Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration. r"""
Parameters: Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration.
pretrained_model_name_or_path: either:
Args:
- a string with the `shortcut name` of a pre-trained model configuration to load from cache or download, e.g.: ``bert-base-uncased``. pretrained_model_name_or_path (:obj:`string`):
- a string with the `identifier name` of a pre-trained model configuration that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``. either:
- a path to a `directory` containing a configuration file saved using the :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``. - a string with the `shortcut name` of a pre-trained model configuration to load from cache or
- a path or url to a saved configuration JSON `file`, e.g.: ``./my_model_directory/configuration.json``. download, e.g.: ``bert-base-uncased``.
- a string with the `identifier name` of a pre-trained model configuration that was user-uploaded to
cache_dir: (`optional`) string: our S3, e.g.: ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing a configuration file saved using the
:func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``.
- a path or url to a saved configuration JSON `file`, e.g.:
``./my_model_directory/configuration.json``.
cache_dir (:obj:`string`, `optional`):
Path to a directory in which a downloaded pre-trained model Path to a directory in which a downloaded pre-trained model
configuration should be cached if the standard cache should not be used. configuration should be cached if the standard cache should not be used.
kwargs (:obj:`Dict[str, Any]`, `optional`):
kwargs: (`optional`) dict: key/value pairs with which to update the configuration object after loading. The values in kwargs of any keys which are configuration attributes will be used to override the loaded
values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is
- The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. controlled by the `return_unused_kwargs` keyword parameter.
- Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled by the `return_unused_kwargs` keyword parameter. force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
force_download: (`optional`) boolean, default False: resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete an incompletely received file. Attempt to resume the download if such a file exists. Do not delete an incompletely received file. Attempt to resume the download if such a file exists.
proxies (:obj:`Dict`, `optional`):
proxies: (`optional`) dict, default None: A dictionary of proxy servers to use by protocol or endpoint, e.g.:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`.
The proxies are used on each request. The proxies are used on each request.
return_unused_kwargs: (`optional`) bool: return_unused_kwargs: (`optional`) bool:
If False, then this function returns just the final configuration object.
If True, then this function returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs` is a
dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e. the part
of kwargs which has not been used to update `config` and is otherwise ignored.
- If False, then this function returns just the final configuration object. Returns:
- If True, then this functions returns a tuple `(config, unused_kwargs)` where `unused_kwargs` is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: ie the part of kwargs which has not been used to update `config` and is otherwise ignored. :class:`PretrainedConfig`: An instance of a configuration object
Examples:: Examples::
...@@ -169,9 +184,14 @@ class PretrainedConfig(object): ...@@ -169,9 +184,14 @@ class PretrainedConfig(object):
for instantiating a Config using `from_dict`. for instantiating a Config using `from_dict`.
Parameters: Parameters:
pretrained_config_archive_map: (`optional`) Dict: pretrained_model_name_or_path (:obj:`string`):
A map of `shortcut names` to `url`. The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
By default, will use the current class attribute. pretrained_config_archive_map (:obj:`Dict[str, str]`, `optional`):
A map of `shortcut names` to `url`. By default, will use the current class attribute.
Returns:
:obj:`Tuple[Dict, Dict]`: A tuple containing the dictionary that will be used to instantiate the configuration object, followed by the remaining keyword arguments.
""" """
cache_dir = kwargs.pop("cache_dir", None) cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
...@@ -235,8 +255,21 @@ class PretrainedConfig(object): ...@@ -235,8 +255,21 @@ class PretrainedConfig(object):
return config_dict, kwargs return config_dict, kwargs
@classmethod @classmethod
def from_dict(cls, config_dict: Dict, **kwargs): def from_dict(cls, config_dict: Dict, **kwargs) -> 'PretrainedConfig':
"""Constructs a `Config` from a Python dictionary of parameters.""" """
Constructs a `Config` from a Python dictionary of parameters.
Args:
config_dict (:obj:`Dict[str, Any]`):
Dictionary that will be used to instantiate the configuration object. Such a dictionary can be retrieved
from a pre-trained checkpoint by leveraging the :func:`~transformers.PretrainedConfig.get_config_dict`
method.
kwargs (:obj:`Dict[str, Any]`):
Additional parameters from which to initialize the configuration object.
Returns:
:class:`PretrainedConfig`: An instance of a configuration object
"""
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
config = cls(**config_dict) config = cls(**config_dict)
...@@ -260,8 +293,18 @@ class PretrainedConfig(object): ...@@ -260,8 +293,18 @@ class PretrainedConfig(object):
return config return config
@classmethod @classmethod
def from_json_file(cls, json_file: str): def from_json_file(cls, json_file: str) -> 'PretrainedConfig':
"""Constructs a `Config` from the path to a json file of parameters.""" """
Constructs a `Config` from the path to a json file of parameters.
Args:
json_file (:obj:`string`):
Path to the JSON file containing the parameters.
Returns:
:class:`PretrainedConfig`: An instance of a configuration object
"""
config_dict = cls._dict_from_json_file(json_file) config_dict = cls._dict_from_json_file(json_file)
return cls(**config_dict) return cls(**config_dict)
...@@ -278,17 +321,33 @@ class PretrainedConfig(object): ...@@ -278,17 +321,33 @@ class PretrainedConfig(object):
return "{} {}".format(self.__class__.__name__, self.to_json_string()) return "{} {}".format(self.__class__.__name__, self.to_json_string())
def to_dict(self): def to_dict(self):
"""Serializes this instance to a Python dictionary.""" """
Serializes this instance to a Python dictionary.
Returns:
:obj:`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__) output = copy.deepcopy(self.__dict__)
if hasattr(self.__class__, "model_type"): if hasattr(self.__class__, "model_type"):
output["model_type"] = self.__class__.model_type output["model_type"] = self.__class__.model_type
return output return output
def to_json_string(self): def to_json_string(self):
"""Serializes this instance to a JSON string.""" """
Serializes this instance to a JSON string.
Returns:
:obj:`string`: String containing all the attributes that make up this configuration instance in JSON format.
"""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path): def to_json_file(self, json_file_path):
""" Save this instance to a json file.""" """
Save this instance to a json file.
Args:
json_file_path (:obj:`string`):
Path to the JSON file in which this configuration instance's parameters will be saved.
"""
with open(json_file_path, "w", encoding="utf-8") as writer: with open(json_file_path, "w", encoding="utf-8") as writer:
writer.write(self.to_json_string()) writer.write(self.to_json_string())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment