configuration_utils.py 10.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Patrick von Platen's avatar
Patrick von Platen committed
16
"""Configuration base class and utilities."""
17
18
19


import copy
Patrick von Platen's avatar
improve  
Patrick von Platen committed
20
import inspect
21
22
23
24
25
import json
import os
import re
from typing import Any, Dict, Tuple, Union

26
from huggingface_hub import hf_hub_download
Patrick von Platen's avatar
Patrick von Platen committed
27
from requests import HTTPError
28

Patrick von Platen's avatar
Patrick von Platen committed
29
from . import __version__
30
31
from .utils import (
    DIFFUSERS_CACHE,
Patrick von Platen's avatar
Patrick von Platen committed
32
    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
33
34
35
36
37
38
    EntryNotFoundError,
    RepositoryNotFoundError,
    RevisionNotFoundError,
    logging,
)

39

40
41
42
43
44
# Module-level logger, obtained from the project's `logging` helper imported from `.utils`.
logger = logging.get_logger(__name__)

# Matches names of the form `config.<anything>.json`, capturing the middle part.
# NOTE(review): not referenced anywhere in the visible part of this file - confirm it is still needed.
_re_configuration_file = re.compile(r"config\.(.*)\.json")


Patrick von Platen's avatar
Patrick von Platen committed
45
class ConfigMixin:
    r"""
    Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
    methods for loading/downloading/saving configurations.

    Subclasses must override the class attribute `config_name` with the file name under which their configuration is
    stored; `register` and `get_config_dict` raise if it is left as `None`.
    """

    # File name of the serialized configuration; must be overridden by subclasses.
    config_name = None

    def register(self, **kwargs):
        """
        Store `kwargs` as attributes of the instance and record them as part of the configuration to save.

        The keys `_class_name` and `_diffusers_version` are added automatically. Calling this method several times
        accumulates the recorded values.

        Args:
            kwargs:
                Configuration values to set on the instance and to record for serialization.

        Raises:
            `NotImplementedError`: If the class does not define `config_name`.
            `AttributeError`: If one of the values cannot be set as an attribute of the instance.
        """
        if self.config_name is None:
            raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")

        kwargs["_class_name"] = self.__class__.__name__
        kwargs["_diffusers_version"] = __version__

        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError:
                logger.error(f"Can't set {key} with value {value} for {self}")
                # Bare `raise` re-raises the active exception with its original traceback.
                raise

        if not hasattr(self, "_dict_to_save"):
            self._dict_to_save = {}

        self._dict_to_save.update(kwargs)

    def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
        [`~ConfigMixin.from_config`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file will be saved (will be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Accepted for interface compatibility; not used by this method.
            kwargs:
                Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
                Accepted for interface compatibility; not used by this method.

        Raises:
            `AssertionError`: If `save_directory` points to an existing file.
        """
        if os.path.isfile(save_directory):
            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

        os.makedirs(save_directory, exist_ok=True)

        # If we save using the predefined names, we can load using `from_config`
        output_config_file = os.path.join(save_directory, self.config_name)

        self.to_json_file(output_config_file)
        logger.info(f"Configuration saved in {output_config_file}")

    @classmethod
    def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
        """
        Instantiate the class from a configuration file located locally or on the Hub.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Path to a configuration file, path to a directory containing a file named `config_name`, or a
                repository identifier to download that file from.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether to also return the configuration entries and keyword arguments that were not consumed by
                `__init__`.
            kwargs:
                Forwarded both to [`~ConfigMixin.get_config_dict`] (download options) and to
                [`~ConfigMixin.extract_init_dict`] (overrides of configuration values). Download options are not
                stripped before the second call, so they show up in the unused keyword arguments.

        Returns:
            The instantiated object, or a tuple `(object, unused_kwargs)` if `return_unused_kwargs` is `True`.
        """
        config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)

        init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)

        model = cls(**init_dict)

        if return_unused_kwargs:
            return model, unused_kwargs
        return model

    @classmethod
    def get_config_dict(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> Dict[str, Any]:
        """
        Resolve `pretrained_model_name_or_path` to a configuration file and load it as a dictionary.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Path to a configuration file, path to a directory containing a file named `config_name`, or a
                repository identifier passed to `hf_hub_download`.
            kwargs:
                Optional download options: `cache_dir`, `force_download`, `resume_download`, `proxies`,
                `use_auth_token`, `local_files_only` and `revision`. They are only used when the file has to be
                fetched with `hf_hub_download`.

        Returns:
            `Dict[str, Any]`: The parsed content of the configuration file.

        Raises:
            `ValueError`: If the class does not define `config_name`.
            `EnvironmentError`: If the configuration file cannot be located, downloaded or parsed as JSON.
        """
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        use_auth_token = kwargs.pop("use_auth_token", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)

        user_agent = {"file_type": "config"}

        pretrained_model_name_or_path = str(pretrained_model_name_or_path)

        if cls.config_name is None:
            raise ValueError(
                "`self.config_name` is not defined. Note that one should not load a config from "
                "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
            )

        if os.path.isfile(pretrained_model_name_or_path):
            config_file = pretrained_model_name_or_path
        elif os.path.isdir(pretrained_model_name_or_path):
            config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
            if not os.path.isfile(config_file):
                raise EnvironmentError(
                    f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
                )
        else:
            try:
                # Load from URL or cache if already cached
                config_file = hf_hub_download(
                    pretrained_model_name_or_path,
                    filename=cls.config_name,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    user_agent=user_agent,
                    # Forward the requested revision; it was previously parsed but never used.
                    revision=revision,
                )
            except RepositoryNotFoundError as err:
                raise EnvironmentError(
                    f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed"
                    " on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token"
                    " having permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and"
                    " pass `use_auth_token=True`."
                ) from err
            except RevisionNotFoundError as err:
                raise EnvironmentError(
                    f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
                    " this model name. Check the model page at"
                    f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
                ) from err
            except EntryNotFoundError as err:
                raise EnvironmentError(
                    f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
                ) from err
            except HTTPError as err:
                raise EnvironmentError(
                    "There was a specific connection error when trying to load"
                    f" {pretrained_model_name_or_path}:\n{err}"
                ) from err
            except ValueError as err:
                raise EnvironmentError(
                    f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
                    f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
                    f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
                    " run the library in offline mode at"
                    " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
                ) from err
            except EnvironmentError as err:
                raise EnvironmentError(
                    f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
                    "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
                    f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
                    f"containing a {cls.config_name} file"
                ) from err

        try:
            # Load config dict
            config_dict = cls._dict_from_json_file(config_file)
        except (json.JSONDecodeError, UnicodeDecodeError) as err:
            raise EnvironmentError(
                f"It looks like the config file at '{config_file}' is not a valid JSON file."
            ) from err

        return config_dict

    @classmethod
    def extract_init_dict(cls, config_dict, **kwargs):
        """
        Split a configuration dictionary into the arguments accepted by `__init__` and everything else.

        For every named parameter of `cls.__init__` (excluding `self`, `*args` and `**kwargs`), the value is taken
        from `kwargs` if present, otherwise from `config_dict`. A warning is logged for parameters found in neither.

        Args:
            config_dict (`Dict[str, Any]`):
                Configuration values, typically loaded from a JSON file. It is modified in place: consumed keys are
                removed and the remaining `kwargs` are merged into it.
            kwargs:
                Values overriding the entries of `config_dict`.

        Returns:
            `Tuple[Dict[str, Any], Dict[str, Any]]`: The arguments to pass to `__init__`, and a dictionary with the
            entries of `config_dict` and `kwargs` that were not consumed.
        """
        # `*args` / `**kwargs` are catch-alls, not configuration keys, so they are never expected in the config.
        variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        expected_keys = {
            name
            for name, parameter in inspect.signature(cls.__init__).parameters.items()
            if parameter.kind not in variadic_kinds
        }
        expected_keys.discard("self")

        init_dict = {}
        for key in expected_keys:
            if key in kwargs:
                # overwrite key; also drop the overridden config value so it is not reported as unused
                init_dict[key] = kwargs.pop(key)
                config_dict.pop(key, None)
            elif key in config_dict:
                # use value from config dict
                init_dict[key] = config_dict.pop(key)

        # `dict.update` returns `None`, so merge first and build the result from the merged dictionary.
        config_dict.update(kwargs)
        unused_kwargs = dict(config_dict)

        missing_keys = expected_keys - set(init_dict.keys())
        if len(missing_keys) > 0:
            logger.warning(
                f"{missing_keys} was not found in config. Values will be initialized to default values."
            )

        return init_dict, unused_kwargs

    @classmethod
    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
        """Read the UTF-8 encoded file `json_file` and return its parsed JSON content."""
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return json.loads(text)

    # def __eq__(self, other):
    #    return self.__dict__ == other.__dict__

    # def __repr__(self):
    #    return f"{self.__class__.__name__} {self.to_json_string()}"

    @property
    def config(self) -> Dict[str, Any]:
        """
        Returns:
            `Dict[str, Any]`: A deep copy of the values recorded through [`~ConfigMixin.register`], so that modifying
            the result does not affect the instance.
        """
        output = copy.deepcopy(self._dict_to_save)
        return output

    def to_json_string(self) -> str:
        """
        Serializes this instance to a JSON string.

        Returns:
            `str`: String containing all the attributes that make up this configuration instance in JSON format.
        """
        config_dict = self._dict_to_save
        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
        """
        Save this instance to a JSON file.

        Args:
            json_file_path (`str` or `os.PathLike`):
                Path to the JSON file in which this configuration instance's parameters will be saved.
        """
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string())