modelcard.py 9.8 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Configuration base class and utilities."""


import copy
import json
import os

22
23
24
25
from .file_utils import (
    CONFIG_NAME,
    MODEL_CARD_NAME,
    TF2_WEIGHTS_NAME,
Aymeric Augustin's avatar
Aymeric Augustin committed
26
    WEIGHTS_NAME,
27
28
    cached_path,
    hf_bucket_url,
Aymeric Augustin's avatar
Aymeric Augustin committed
29
    is_remote_url,
30
)
Sylvain Gugger's avatar
Sylvain Gugger committed
31
from .models.auto.configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
Lysandre Debut's avatar
Lysandre Debut committed
32
from .utils import logging
thomwolf's avatar
thomwolf committed
33
34


Lysandre Debut's avatar
Lysandre Debut committed
35
# Module-level logger used by the ModelCard methods below to report loading/saving progress and errors.
logger = logging.get_logger(__name__)
thomwolf's avatar
thomwolf committed
36
37


38
class ModelCard:
    r"""
    Structured Model Card class. Store model card as well as methods for loading/downloading/saving model cards.

    Please read the following paper for details and explanation on the sections: "Model Cards for Model Reporting" by
    Margaret Mitchell, Simone Wu, Andrew Zaldivar, Parker Barnes, Lucy Vasserman, Ben Hutchinson, Elena Spitzer,
    Inioluwa Deborah Raji and Timnit Gebru for the proposal behind model cards. Link: https://arxiv.org/abs/1810.03993

    Note: A model card can be loaded and saved to disk.

    Parameters:
        kwargs: key/value pairs used to populate the model card. The sections recommended by the paper
            (``model_details``, ``intended_use``, ``factors``, ``metrics``, ``evaluation_data``, ``training_data``,
            ``quantitative_analyses``, ``ethical_considerations``, ``caveats_and_recommendations``) default to an
            empty dict when not given; any other key is stored as an additional attribute.
    """

    def __init__(self, **kwargs):
        # Recommended attributes from https://arxiv.org/abs/1810.03993 (see papers).
        # Each default is a fresh dict per call, so instances never share mutable state.
        self.model_details = kwargs.pop("model_details", {})
        self.intended_use = kwargs.pop("intended_use", {})
        self.factors = kwargs.pop("factors", {})
        self.metrics = kwargs.pop("metrics", {})
        self.evaluation_data = kwargs.pop("evaluation_data", {})
        self.training_data = kwargs.pop("training_data", {})
        self.quantitative_analyses = kwargs.pop("quantitative_analyses", {})
        self.ethical_considerations = kwargs.pop("ethical_considerations", {})
        self.caveats_and_recommendations = kwargs.pop("caveats_and_recommendations", {})

        # Open additional attributes: any remaining key becomes an attribute of the card.
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError as err:
                logger.error("Can't set {} with value {} for {}".format(key, value, self))
                raise err

    def save_pretrained(self, save_directory_or_file):
        """
        Save a model card object to the directory or file `save_directory_or_file`.

        If `save_directory_or_file` is an existing directory, the card is written there under the standard model
        card file name (so it can be reloaded with :func:`~transformers.ModelCard.from_pretrained`); otherwise the
        argument is used as the output file path.
        """
        if os.path.isdir(save_directory_or_file):
            # If we save using the predefined names, we can load using `from_pretrained`
            output_model_card_file = os.path.join(save_directory_or_file, MODEL_CARD_NAME)
        else:
            output_model_card_file = save_directory_or_file

        self.to_json_file(output_model_card_file)
        logger.info("Model card saved in {}".format(output_model_card_file))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        r"""
        Instantiate a :class:`~transformers.ModelCard` from a pre-trained model card.

        Parameters:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model card to load from cache or download, e.g.:
                  ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model card that was user-uploaded to our S3,
                  e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing a model card file saved using the
                  :func:`~transformers.ModelCard.save_pretrained` method, e.g.: ``./my_model_directory/``.
                - a path or url to a saved model card JSON `file`, e.g.: ``./my_model_directory/modelcard.json``.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model card should be cached if the standard cache
                should not be used.

            kwargs: (`optional`) dict: key/value pairs with which to update the ModelCard object after loading.

                - The values in kwargs of any keys which are model card attributes will be used to override the loaded
                  values.
                - Behavior concerning key/value pairs whose keys are *not* model card attributes is controlled by the
                  `return_unused_kwargs` keyword parameter.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.

            find_from_standard_name: (`optional`) boolean, default True:
                If the pretrained_model_name_or_path ends with our standard model or config filenames, replace them
                with our standard modelcard filename. Can be used to directly feed a model/config url and access the
                colocated modelcard. Only the trailing file name is replaced.

            return_unused_kwargs: (`optional`) bool:

                - If False, then this function returns just the final model card object.
                - If True, then this functions returns a tuple `(model card, unused_kwargs)` where `unused_kwargs` is a
                  dictionary consisting of the key/value pairs whose keys are not model card attributes: ie the part of
                  kwargs which has not been used to update `ModelCard` and is otherwise ignored.

        Returns:
            The loaded :class:`~transformers.ModelCard`, or the tuple described under `return_unused_kwargs`. If the
            model card file cannot be found or parsed, an empty model card is used instead of raising.

        Examples::

            modelcard = ModelCard.from_pretrained('bert-base-uncased')    # Download model card from S3 and cache.
            modelcard = ModelCard.from_pretrained('./test/saved_model/')  # E.g. model card was saved using `save_pretrained('./test/saved_model/')`
            modelcard = ModelCard.from_pretrained('./test/saved_model/modelcard.json')
            modelcard, unused_kwargs = ModelCard.from_pretrained(
                'bert-base-uncased', model_details={'name': 'bert'}, foo=False, return_unused_kwargs=True
            )
            assert modelcard.model_details == {'name': 'bert'}
            assert unused_kwargs == {'foo': False}

        """
        cache_dir = kwargs.pop("cache_dir", None)
        proxies = kwargs.pop("proxies", None)
        find_from_standard_name = kwargs.pop("find_from_standard_name", True)
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

        if pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
            # For simplicity we use the same pretrained url than the configuration files
            # but with a different suffix (modelcard.json). This suffix is replaced below.
            model_card_file = ALL_PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
        elif os.path.isdir(pretrained_model_name_or_path):
            model_card_file = os.path.join(pretrained_model_name_or_path, MODEL_CARD_NAME)
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            model_card_file = pretrained_model_name_or_path
        else:
            model_card_file = hf_bucket_url(pretrained_model_name_or_path, filename=MODEL_CARD_NAME, mirror=None)

        if find_from_standard_name or pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
            # Swap only the trailing file name, as documented for `find_from_standard_name`. A plain substring
            # replace would also rewrite directory names or URL segments that merely contain a standard file name.
            for standard_name in (CONFIG_NAME, WEIGHTS_NAME, TF2_WEIGHTS_NAME):
                if model_card_file.endswith(standard_name):
                    model_card_file = model_card_file[: -len(standard_name)] + MODEL_CARD_NAME
                    break

        try:
            # Load from URL or cache if already cached
            resolved_model_card_file = cached_path(model_card_file, cache_dir=cache_dir, proxies=proxies)
            if resolved_model_card_file == model_card_file:
                logger.info("loading model card file {}".format(model_card_file))
            else:
                logger.info(
                    "loading model card file {} from cache at {}".format(model_card_file, resolved_model_card_file)
                )
            # Load model card
            modelcard = cls.from_json_file(resolved_model_card_file)

        except (EnvironmentError, json.JSONDecodeError):
            # A missing or unreadable card is not fatal: fall back on an empty model card, but report it.
            logger.info("Could not load model card file {}, using an empty model card".format(model_card_file))
            modelcard = cls()

        # Update model card with kwargs if needed. Keys matching an existing attribute are consumed; the rest stay
        # in `kwargs` and are returned to the caller when `return_unused_kwargs` is set.
        for key in [key for key in kwargs if hasattr(modelcard, key)]:
            setattr(modelcard, key, kwargs.pop(key))

        # Pass the object rather than str(modelcard) so JSON serialization only happens if INFO is enabled.
        logger.info("Model card: %s", modelcard)
        if return_unused_kwargs:
            return modelcard, kwargs
        return modelcard

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `ModelCard` from a Python dictionary of parameters."""
        return cls(**json_object)

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `ModelCard` from a json file of parameters (read as UTF-8)."""
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        dict_obj = json.loads(text)
        return cls(**dict_obj)

    def __eq__(self, other):
        """Two model cards are equal when all of their attributes are equal."""
        if not isinstance(other, ModelCard):
            # Let Python fall back to its default comparison instead of raising AttributeError
            # (e.g. for ``modelcard == None``).
            return NotImplemented
        return self.__dict__ == other.__dict__

    def __repr__(self):
        return self.to_json_string()

    def to_dict(self):
        """Serializes this instance to a Python dictionary (a deep copy, so callers cannot mutate the card)."""
        return copy.deepcopy(self.__dict__)

    def to_json_string(self):
        """Serializes this instance to a JSON string (2-space indent, sorted keys, trailing newline)."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path):
        """ Save this instance to a json file."""
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string())