# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Configuration base class and utilities."""


import copy
import json
import os
import re
from typing import Any, Dict, Tuple, Union

from requests import HTTPError
from transformers.utils import (
    CONFIG_NAME,
    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
    EntryNotFoundError,
    RepositoryNotFoundError,
    RevisionNotFoundError,
    cached_path,
    hf_bucket_url,
    is_offline_mode,
    is_remote_url,
    logging,
)

from . import __version__


logger = logging.get_logger(__name__)

_re_configuration_file = re.compile(r"config\.(.*)\.json")


class PretrainedConfig:
    r"""
    Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
    methods for loading/downloading/saving configurations.
    """
    model_type: str = ""

    def __init__(self, **kwargs):
        # Name or path to the pretrained checkpoint
        self._name_or_path = str(kwargs.pop("name_or_path", ""))
        # Keep the diffusers version recorded in the loaded config (if any). Any other keyword arguments are
        # ignored by this base class.
        self.diffusers_version = kwargs.pop("diffusers_version", None)

    @property
    def name_or_path(self) -> str:
        return getattr(self, "_name_or_path", None)

    @name_or_path.setter
    def name_or_path(self, value):
        self._name_or_path = str(value)  # Make sure that name_or_path is a string (for JSON encoding)

    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
        [`~PretrainedConfig.from_pretrained`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file will be saved (will be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Currently ignored by this method; accepted for API compatibility.
            kwargs:
                Currently ignored by this method; accepted for API compatibility.

        Raises:
            AssertionError: If `save_directory` points to an existing file.
        """
        if os.path.isfile(save_directory):
            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

        os.makedirs(save_directory, exist_ok=True)

        # If we save using the predefined names, we can load using `from_pretrained`
        output_config_file = os.path.join(save_directory, CONFIG_NAME)

        self.to_json_file(output_config_file, use_diff=True)
        logger.info(f"Configuration saved in {output_config_file}")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        r"""
        Instantiate a [`PretrainedConfig`] (or a derived class) from a pretrained model configuration.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                This can be either:

                    - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
                      huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or
                      namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
                    - a path to a *directory* containing a configuration file saved using the
                      [`~PretrainedConfig.save_pretrained`] method, e.g., `./my_model_directory/`.
                    - a path or url to a saved configuration JSON *file*, e.g.,
                      `./my_model_directory/configuration.json`.
            cache_dir (`str` or `os.PathLike`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force to (re-)download the configuration files and override the cached versions if
                they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to delete incompletely received file. Attempts to resume the download if such a file
                exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
                when running `huggingface-cli login` (stored in `~/.huggingface`).
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files. Forced to `True` when offline mode is enabled.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                If `False`, then this function returns just the final configuration object.

                If `True`, then this functions returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
                dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
                part of `kwargs` which has not been used to update `config` and is otherwise ignored.
            kwargs (`Dict[str, Any]`, *optional*):
                The values in kwargs of any keys which are configuration attributes will be used to override the loaded
                values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
                by the `return_unused_kwargs` keyword parameter.

        <Tip>

        Passing `use_auth_token=True` is required when you want to use a private model.

        </Tip>

        Returns:
            [`PretrainedConfig`]: The configuration object instantiated from this pretrained model.

        Raises:
            EnvironmentError: If the configuration file cannot be located, downloaded, or parsed as JSON.

        Examples:

        ```python
        # We can't instantiate directly the base class *PretrainedConfig* so let's show the examples on a
        # derived class: BertConfig
        config = BertConfig.from_pretrained(
            "bert-base-uncased"
        )  # Download configuration from huggingface.co and cache.
        config = BertConfig.from_pretrained(
            "./test/saved_model/"
        )  # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
        config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
        config = BertConfig.from_pretrained("bert-base-uncased", output_attentions=True, foo=False)
        assert config.output_attentions == True
        config, unused_kwargs = BertConfig.from_pretrained(
            "bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
        )
        assert config.output_attentions == True
        assert unused_kwargs == {"foo": False}
        ```"""
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # Loading a config saved for a different model type is allowed, but only warned about.
        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)

    @classmethod
    def get_config_dict(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
        [`PretrainedConfig`] using `from_dict`.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.

        Returns:
            `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object.

        """
        # Get config dict associated with the base config file
        config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)

        return config_dict, kwargs

    @classmethod
    def _get_config_dict(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        Locate (downloading or reading from cache if needed) and parse the configuration JSON file.

        Consumes the download-related keyword arguments (`cache_dir`, `force_download`, `resume_download`,
        `proxies`, `use_auth_token`, `local_files_only`, `revision`) and returns the remaining ones alongside the
        parsed dictionary.

        Raises:
            EnvironmentError: If the file cannot be resolved or is not valid JSON.
        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        use_auth_token = kwargs.pop("use_auth_token", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)

        user_agent = {"file_type": "config"}

        if is_offline_mode() and not local_files_only:
            logger.info("Offline mode: forcing local_files_only=True")
            local_files_only = True

        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
        if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            config_file = pretrained_model_name_or_path
            # `configuration_file` is referenced by the error messages below; without this binding a failure on a
            # direct file path / URL surfaced as an UnboundLocalError instead of the intended EnvironmentError.
            configuration_file = os.path.basename(config_file)
        else:
            configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)

            if os.path.isdir(pretrained_model_name_or_path):
                config_file = os.path.join(pretrained_model_name_or_path, configuration_file)
            else:
                config_file = hf_bucket_url(
                    pretrained_model_name_or_path, filename=configuration_file, revision=revision, mirror=None
                )

        try:
            # Load from URL or cache if already cached
            resolved_config_file = cached_path(
                config_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                user_agent=user_agent,
            )

        except RepositoryNotFoundError as err:
            raise EnvironmentError(
                f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
                "'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
                "permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
                "`use_auth_token=True`."
            ) from err
        except RevisionNotFoundError as err:
            raise EnvironmentError(
                f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
                f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
                "available revisions."
            ) from err
        except EntryNotFoundError as err:
            raise EnvironmentError(
                f"{pretrained_model_name_or_path} does not appear to have a file named {configuration_file}."
            ) from err
        except HTTPError as err:
            raise EnvironmentError(
                f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
            ) from err
        except ValueError as err:
            raise EnvironmentError(
                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
                f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
                f" containing a {configuration_file} file.\nCheckout your internet connection or see how to run the"
                " library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
            ) from err
        except EnvironmentError as err:
            raise EnvironmentError(
                f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
                "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
                f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
                f"containing a {configuration_file} file"
            ) from err

        try:
            # Load config dict
            config_dict = cls._dict_from_json_file(resolved_config_file)
        except (json.JSONDecodeError, UnicodeDecodeError) as err:
            raise EnvironmentError(
                f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
            ) from err

        if resolved_config_file == config_file:
            logger.info(f"loading configuration file {config_file}")
        else:
            logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")

        return config_dict, kwargs

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
        """
        Instantiates a [`PretrainedConfig`] from a Python dictionary of parameters.

        Args:
            config_dict (`Dict[str, Any]`):
                Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
                retrieved from a pretrained checkpoint by leveraging the [`~PretrainedConfig.get_config_dict`] method.
            kwargs (`Dict[str, Any]`):
                Additional parameters from which to initialize the configuration object. Keys naming an existing
                attribute of the instantiated config override that attribute; `return_unused_kwargs=True` makes this
                method return `(config, unused_kwargs)` with the remaining keys.

        Returns:
            [`PretrainedConfig`]: The configuration object instantiated from those parameters.
        """
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

        config = cls(**config_dict)

        # Apply overrides for keys that are existing attributes, consuming them so that only the genuinely unused
        # keyword arguments remain in `kwargs`.
        for key in list(kwargs):
            if hasattr(config, key):
                setattr(config, key, kwargs.pop(key))

        logger.info(f"Model config {config}")
        if return_unused_kwargs:
            return config, kwargs
        else:
            return config

    @classmethod
    def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "PretrainedConfig":
        """
        Instantiates a [`PretrainedConfig`] from the path to a JSON file of parameters.

        Args:
            json_file (`str` or `os.PathLike`):
                Path to the JSON file containing the parameters.

        Returns:
            [`PretrainedConfig`]: The configuration object instantiated from that JSON file.

        """
        config_dict = cls._dict_from_json_file(json_file)
        return cls(**config_dict)

    @classmethod
    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
        """Read `json_file` as UTF-8 text and return the decoded JSON object."""
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return json.loads(text)

    def __eq__(self, other):
        # Non-config operands (e.g. `None`) have no comparable `__dict__`; previously this raised AttributeError.
        # Returning NotImplemented lets Python fall back to its default comparison.
        if not isinstance(other, PretrainedConfig):
            return NotImplemented
        return self.__dict__ == other.__dict__

    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string()}"

    def to_diff_dict(self) -> Dict[str, Any]:
        """
        Removes all attributes from config which correspond to the default config attributes for better readability and
        serializes to a Python dictionary.

        Note: this instantiates both `PretrainedConfig()` and `self.__class__()` with no arguments to obtain the
        defaults to compare against.

        Returns:
            `Dict[str, Any]`: Dictionary of the attributes of this configuration instance that differ from the
            defaults (`diffusers_version` is always kept).
        """
        config_dict = self.to_dict()

        # get the default config dict
        default_config_dict = PretrainedConfig().to_dict()

        # get class specific config dict
        class_config_dict = self.__class__().to_dict()

        serializable_config_dict = {}

        # only serialize values that differ from the default config
        for key, value in config_dict.items():
            if (
                key not in default_config_dict
                or key == "diffusers_version"
                or value != default_config_dict[key]
                or (key in class_config_dict and value != class_config_dict[key])
            ):
                serializable_config_dict[key] = value

        self.dict_torch_dtype_to_str(serializable_config_dict)

        return serializable_config_dict

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        if hasattr(self.__class__, "model_type"):
            output["model_type"] = self.__class__.model_type
        if "_auto_class" in output:
            del output["_auto_class"]

        # Diffusers version when serializing the model
        output["diffusers_version"] = __version__

        self.dict_torch_dtype_to_str(output)

        return output

    def to_json_string(self, use_diff: bool = True) -> str:
        """
        Serializes this instance to a JSON string.

        Args:
            use_diff (`bool`, *optional*, defaults to `True`):
                If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
                is serialized to JSON string.

        Returns:
            `str`: String containing all the attributes that make up this configuration instance in JSON format.
        """
        if use_diff is True:
            config_dict = self.to_diff_dict()
        else:
            config_dict = self.to_dict()
        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
        """
        Save this instance to a JSON file.

        Args:
            json_file_path (`str` or `os.PathLike`):
                Path to the JSON file in which this configuration instance's parameters will be saved.
            use_diff (`bool`, *optional*, defaults to `True`):
                If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
                is serialized to JSON file.
        """
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string(use_diff=use_diff))

    def update(self, config_dict: Dict[str, Any]):
        """
        Updates attributes of this class with attributes from `config_dict`.

        Args:
            config_dict (`Dict[str, Any]`): Dictionary of attributes that should be updated for this class.
        """
        for key, value in config_dict.items():
            setattr(self, key, value)

    def update_from_string(self, update_str: str):
        """
        Updates attributes of this class with attributes from `update_str`.

        The expected format is ints, floats and strings as is, and for booleans use `true` or `false`. For example:
        "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"

        The keys to change have to already exist in the config object. Only the first `=` of each item separates key
        from value, so string values may themselves contain `=`.

        Args:
            update_str (`str`): String with attributes that should be updated for this class.

        Raises:
            ValueError: If a key is unknown, a value cannot be converted to the existing attribute's type, or the
                existing attribute is not an int, float, bool or string.
        """
        # maxsplit=1: a value containing "=" must not break the key/value pairing.
        d = dict(x.split("=", 1) for x in update_str.split(","))
        for k, v in d.items():
            if not hasattr(self, k):
                raise ValueError(f"key {k} isn't in the original config dict")

            old_v = getattr(self, k)
            # bool is checked before int because bool is a subclass of int.
            if isinstance(old_v, bool):
                if v.lower() in ["true", "1", "y", "yes"]:
                    v = True
                elif v.lower() in ["false", "0", "n", "no"]:
                    v = False
                else:
                    raise ValueError(f"can't derive true or false from {v} (key {k})")
            elif isinstance(old_v, int):
                v = int(v)
            elif isinstance(old_v, float):
                v = float(v)
            elif not isinstance(old_v, str):
                raise ValueError(
                    f"You can only update int, float, bool or string values in the config, got {v} for key {k}"
                )

            setattr(self, k, v)

    def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
        """
        Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None,
        converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
        string, which can then be stored in the json format. The dictionary is modified in place.
        """
        if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
            d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
        for value in d.values():
            if isinstance(value, dict):
                self.dict_torch_dtype_to_str(value)