hub_utils.py 5.57 KB
Newer Older
anton-l's avatar
anton-l committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


anton-l's avatar
anton-l committed
17
import os
18
import sys
anton-l's avatar
anton-l committed
19
from pathlib import Path
20
21
from typing import Dict, Optional, Union
from uuid import uuid4
anton-l's avatar
anton-l committed
22

23
import requests
24
from huggingface_hub import HfFolder, whoami
anton-l's avatar
anton-l committed
25

26
from . import __version__
27
from .utils import ENV_VARS_TRUE_VALUES, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
28
29
30
31
32
33
34
35
36
37
from .utils.import_utils import (
    _flax_version,
    _jax_version,
    _onnxruntime_version,
    _torch_version,
    is_flax_available,
    is_modelcards_available,
    is_onnx_available,
    is_torch_available,
)
38
39
40
41


if is_modelcards_available():
    from modelcards import CardData, ModelCard
anton-l's avatar
anton-l committed
42

anton-l's avatar
anton-l committed
43
44
45
46

logger = logging.get_logger(__name__)


# Markdown template handed to `ModelCard.from_template` by `create_model_card`.
MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md"
# Random hex identifier generated once at import time; embedded in the user-agent string.
SESSION_ID = uuid4().hex
# True when the `HF_HUB_OFFLINE` environment variable (upper-cased) is one of ENV_VARS_TRUE_VALUES.
HF_HUB_OFFLINE = os.getenv("HF_HUB_OFFLINE", "").upper() in ENV_VARS_TRUE_VALUES
# True when the `DISABLE_TELEMETRY` environment variable (upper-cased) is one of ENV_VARS_TRUE_VALUES.
DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").upper() in ENV_VARS_TRUE_VALUES
# URL prefix for telemetry requests; `send_telemetry` appends the log name to it.
HUGGINGFACE_CO_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/"
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75


def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
    """
    Builds the `; `-separated user-agent string attached to outgoing requests.

    The string always starts with the diffusers version, the Python version and the session id. When telemetry is
    disabled, only a `telemetry/off` marker is added and nothing else. Otherwise the versions of the available
    backends (torch, jax/flax, onnxruntime), an optional CI marker and the caller-supplied `user_agent` are appended.

    Args:
        user_agent: extra information to append. A dict is rendered as `key/value` pairs, a string is appended
            verbatim, and any other value (including `None`) is ignored.

    Returns:
        The formatted user-agent string.
    """
    python_version = sys.version.split()[0]
    fields = [f"diffusers/{__version__}", f"python/{python_version}", f"session_id/{SESSION_ID}"]

    if DISABLE_TELEMETRY:
        fields.append("telemetry/off")
        return "; ".join(fields)

    if is_torch_available():
        fields.append(f"torch/{_torch_version}")
    if is_flax_available():
        fields.extend((f"jax/{_jax_version}", f"flax/{_flax_version}"))
    if is_onnx_available():
        fields.append(f"onnxruntime/{_onnxruntime_version}")

    # CI will set this value to True
    running_in_ci = os.environ.get("DIFFUSERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES
    if running_in_ci:
        fields.append("is_ci/true")

    if isinstance(user_agent, dict):
        fields.append("; ".join(f"{key}/{value}" for key, value in user_agent.items()))
    elif isinstance(user_agent, str):
        fields.append(user_agent)

    return "; ".join(fields)
anton-l's avatar
anton-l committed
76
77


78
79
80
81
82
83
84
85
86
def send_telemetry(data: Dict, name: str):
    """
    Sends logs to the Hub telemetry endpoint.

    This is best-effort: nothing is sent when telemetry is disabled or the Hub is in offline mode, and any error
    raised while sending is swallowed so that telemetry can never break the caller.

    Args:
        data: the fields to track, e.g. {"example_name": "dreambooth"}
        name: a unique name to differentiate the telemetry logs, e.g. "diffusers_examples" or "diffusers_notebooks"
    """
    if DISABLE_TELEMETRY or HF_HUB_OFFLINE:
        return

    # The tracked fields travel in the user-agent header; the request itself has no body.
    headers = {"user-agent": http_user_agent(data)}
    endpoint = HUGGINGFACE_CO_TELEMETRY + name
    try:
        # `requests` waits indefinitely by default; bound the wait so a stalled
        # connection cannot hang the caller on a non-essential request.
        r = requests.head(endpoint, headers=headers, timeout=10)
        r.raise_for_status()
    except Exception:
        # We don't want to error in case of connection errors of any kind.
        pass


anton-l's avatar
anton-l committed
99
100
101
102
103
104
105
106
107
108
109
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    """
    Returns the fully qualified Hub repository name `<namespace>/<model_id>`.

    Args:
        model_id: name of the repository without its namespace.
        organization: namespace to use. When `None`, the `"name"` field returned by `whoami` for `token` is used.
        token: Hub token. When `None`, the token returned by `HfFolder.get_token()` is used.

    Returns:
        The string `f"{namespace}/{model_id}"`.
    """
    token = HfFolder.get_token() if token is None else token
    namespace = whoami(token)["name"] if organization is None else organization
    return f"{namespace}/{model_id}"


def create_model_card(args, model_name):
    """
    Renders a model card from the packaged template and saves it as `README.md` in `args.output_dir`.

    Args:
        args: namespace-like object holding the training arguments. `learning_rate`, `train_batch_size`,
            `eval_batch_size`, `mixed_precision` and `output_dir` are read unconditionally; every other attribute
            used here (`hub_token`, `dataset_name`, `gradient_accumulation_steps`, the `adam_*`, `lr_*` and `ema_*`
            values, `local_rank`) is optional and falls back to `None` when missing.
        model_name: name of the model, used in the card and to derive the full repository name.

    Raises:
        ValueError: if the `modelcards` package is not installed.
    """
    # The availability helper has to be *called*: the bare function object is always truthy,
    # which would make this guard unreachable.
    if not is_modelcards_available():
        raise ValueError(
            "Please make sure to have `modelcards` installed when using the `create_model_card` function. You can"
            " install the package with `pip install modelcards`."
        )

    # Nothing is written when `local_rank` is set to something other than -1 or 0
    # (presumably non-main processes of a distributed run - TODO confirm).
    if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
        return

    hub_token = getattr(args, "hub_token", None)
    repo_name = get_full_repo_name(model_name, token=hub_token)
    # Read once so the card metadata and the template agree, and so a missing attribute is tolerated in both.
    dataset_name = getattr(args, "dataset_name", None)

    model_card = ModelCard.from_template(
        card_data=CardData(  # Card metadata object that will be converted to YAML block
            language="en",
            license="apache-2.0",
            library_name="diffusers",
            tags=[],
            datasets=dataset_name,
            metrics=[],
        ),
        template_path=MODEL_CARD_TEMPLATE_PATH,
        model_name=model_name,
        repo_name=repo_name,
        dataset_name=dataset_name,
        learning_rate=args.learning_rate,
        train_batch_size=args.train_batch_size,
        eval_batch_size=args.eval_batch_size,
        gradient_accumulation_steps=getattr(args, "gradient_accumulation_steps", None),
        adam_beta1=getattr(args, "adam_beta1", None),
        adam_beta2=getattr(args, "adam_beta2", None),
        adam_weight_decay=getattr(args, "adam_weight_decay", None),
        adam_epsilon=getattr(args, "adam_epsilon", None),
        lr_scheduler=getattr(args, "lr_scheduler", None),
        lr_warmup_steps=getattr(args, "lr_warmup_steps", None),
        ema_inv_gamma=getattr(args, "ema_inv_gamma", None),
        ema_power=getattr(args, "ema_power", None),
        ema_max_decay=getattr(args, "ema_max_decay", None),
        mixed_precision=args.mixed_precision,
    )

    card_path = os.path.join(args.output_dir, "README.md")
    model_card.save(card_path)