# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import shutil
from pathlib import Path
from typing import Optional

from diffusers import DiffusionPipeline
from huggingface_hub import HfFolder, Repository, whoami

from .utils import is_modelcards_available, logging


if is_modelcards_available():
    from modelcards import CardData, ModelCard


logger = logging.get_logger(__name__)


MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md"


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"


def init_git_repo(args, at_init: bool = False):
    """
    Args:
Patrick von Platen's avatar
Patrick von Platen committed
51
    Initializes a git repo in `args.hub_model_id`.
anton-l's avatar
anton-l committed
52
        at_init (`bool`, *optional*, defaults to `False`):
Patrick von Platen's avatar
Patrick von Platen committed
53
54
            Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is `True`
            and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out.
anton-l's avatar
anton-l committed
55
    """
56
    if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
        return
    hub_token = args.hub_token if hasattr(args, "hub_token") else None
    use_auth_token = True if hub_token is None else hub_token
    if not hasattr(args, "hub_model_id") or args.hub_model_id is None:
        repo_name = Path(args.output_dir).absolute().name
    else:
        repo_name = args.hub_model_id
    if "/" not in repo_name:
        repo_name = get_full_repo_name(repo_name, token=hub_token)

    try:
        repo = Repository(
            args.output_dir,
            clone_from=repo_name,
            use_auth_token=use_auth_token,
            private=args.hub_private_repo,
        )
    except EnvironmentError:
        if args.overwrite_output_dir and at_init:
            # Try again after wiping output_dir
            shutil.rmtree(args.output_dir)
            repo = Repository(
                args.output_dir,
                clone_from=repo_name,
                use_auth_token=use_auth_token,
            )
        else:
            raise

    repo.git_pull()

    # By default, ignore the checkpoint folders
    if not os.path.exists(os.path.join(args.output_dir, ".gitignore")):
        with open(os.path.join(args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
            writer.writelines(["checkpoint-*/"])

    return repo
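
# Usage sketch (argument values are hypothetical): `args` only needs the attributes
# read above, so a bare namespace suffices for illustration.
#
#   from argparse import Namespace
#
#   args = Namespace(
#       output_dir="ddpm-butterflies-128",
#       hub_model_id=None,           # fall back to the output_dir name
#       hub_token=None,              # fall back to the cached HfFolder token
#       hub_private_repo=False,
#       overwrite_output_dir=True,
#       local_rank=-1,
#   )
#   repo = init_git_repo(args, at_init=True)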


def push_to_hub(
    args,
    pipeline: DiffusionPipeline,
    repo: Repository,
    commit_message: Optional[str] = "End of training",
    blocking: bool = True,
    **kwargs,
) -> str:
    """
    Parameters:
Patrick von Platen's avatar
Patrick von Platen committed
106
    Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.
anton-l's avatar
anton-l committed
107
108
109
110
111
112
113
        commit_message (`str`, *optional*, defaults to `"End of training"`):
            Message to commit while pushing.
        blocking (`bool`, *optional*, defaults to `True`):
            Whether the function should return only when the `git push` has finished.
        kwargs:
            Additional keyword arguments passed along to [`create_model_card`].
    Returns:
Patrick von Platen's avatar
Patrick von Platen committed
114
115
        The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of the
        commit and an object to track the progress of the commit if `blocking=True`
anton-l's avatar
anton-l committed
116
117
    """

    if not hasattr(args, "hub_model_id") or args.hub_model_id is None:
        model_name = Path(args.output_dir).name
    else:
        model_name = args.hub_model_id.split("/")[-1]

    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"Saving pipeline checkpoint to {output_dir}")
    pipeline.save_pretrained(output_dir)

    # Only push from one node.
    if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
        return

    # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
    if (
        blocking
        and len(repo.command_queue) > 0
        and repo.command_queue[-1] is not None
        and not repo.command_queue[-1].is_done
    ):
        repo.command_queue[-1]._process.kill()

    git_head_commit_url = repo.push_to_hub(commit_message=commit_message, blocking=blocking, auto_lfs_prune=True)
    # Push the model card separately so it stays independent from the rest of the model
    create_model_card(args, model_name=model_name)
    try:
        repo.push_to_hub(commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True)
    except EnvironmentError as exc:
        logger.error(f"Error pushing update to the model card. Please read logs and retry.\n{exc}")

    return git_head_commit_url
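
# Usage sketch (hypothetical, continuing the `init_git_repo` example above; `args`
# must also carry the training hyperparameters read by `create_model_card`, e.g.
# `learning_rate`, `train_batch_size`, `eval_batch_size` and `mixed_precision`):
#
#   pipeline = DiffusionPipeline.from_pretrained(args.output_dir)
#   url = push_to_hub(args, pipeline, repo, commit_message="End of training")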


def create_model_card(args, model_name):
    if not is_modelcards_available():
        raise ValueError(
            "Please make sure to have `modelcards` installed when using the `create_model_card` function. You can"
            " install the package with `pip install modelcards`."
        )

    if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
        return

    hub_token = args.hub_token if hasattr(args, "hub_token") else None
    repo_name = get_full_repo_name(model_name, token=hub_token)

    model_card = ModelCard.from_template(
        card_data=CardData(  # Card metadata object that will be converted to YAML block
            language="en",
            license="apache-2.0",
            library_name="diffusers",
            tags=[],
            datasets=args.dataset if hasattr(args, "dataset") else None,
            metrics=[],
        ),
        template_path=MODEL_CARD_TEMPLATE_PATH,
        model_name=model_name,
        repo_name=repo_name,
        dataset_name=args.dataset if hasattr(args, "dataset") else None,
        learning_rate=args.learning_rate,
        train_batch_size=args.train_batch_size,
        eval_batch_size=args.eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps
        if hasattr(args, "gradient_accumulation_steps")
        else None,
        adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None,
        adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None,
        adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None,
        adam_epsilon=args.adam_epsilon if hasattr(args, "adam_epsilon") else None,
        lr_scheduler=args.lr_scheduler if hasattr(args, "lr_scheduler") else None,
        lr_warmup_steps=args.lr_warmup_steps if hasattr(args, "lr_warmup_steps") else None,
        ema_inv_gamma=args.ema_inv_gamma if hasattr(args, "ema_inv_gamma") else None,
        ema_power=args.ema_power if hasattr(args, "ema_power") else None,
        ema_max_decay=args.ema_max_decay if hasattr(args, "ema_max_decay") else None,
        mixed_precision=args.mixed_precision,
    )

    card_path = os.path.join(args.output_dir, "README.md")
    model_card.save(card_path)
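
# The saved README.md starts with a YAML metadata block rendered from the `CardData`
# above, roughly of this shape (sketch; exact values depend on `args`):
#
#   ---
#   language: en
#   license: apache-2.0
#   library_name: diffusers
#   datasets: <args.dataset>
#   ---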