Unverified Commit d4c7ab7b authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Hub] feat: explicitly tag to diffusers when using push_to_hub (#6678)



* feat: explicitly tag to diffusers when using push_to_hub

* remove tags.

* reset repo.

* Apply suggestions from code review
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* fix: tests

* fix: push_to_hub behaviour for tagging from save_pretrained

* Apply suggestions from code review
Co-authored-by: default avatarLucain <lucainp@gmail.com>

* Apply suggestions from code review
Co-authored-by: default avatarLucain <lucainp@gmail.com>

* import fixes.

* add library name to existing model card.

* add: standalone test for generate_model_card

* fix tests for standalone method

* moved library_name to a better place.

* merge create_model_card and generate_model_card.

* fix test

* address lucain's comments

* fix return identation

* Apply suggestions from code review
Co-authored-by: default avatarLucain <lucainp@gmail.com>

* address further comments.

* Update src/diffusers/pipelines/pipeline_utils.py
Co-authored-by: default avatarLucain <lucainp@gmail.com>

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarLucain <lucainp@gmail.com>
parent ea9dc3fa
......@@ -42,7 +42,7 @@ from ..utils import (
is_torch_version,
logging,
)
from ..utils.hub_utils import PushToHubMixin
from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card
logger = logging.get_logger(__name__)
......@@ -377,6 +377,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
if push_to_hub:
# Create a new empty model card and eventually tag it
model_card = load_or_create_model_card(repo_id, token=token)
model_card = populate_model_card(model_card)
model_card.save(os.path.join(save_directory, "README.md"))
self._upload_folder(
save_directory,
repo_id,
......
......@@ -60,6 +60,7 @@ from ..utils import (
logging,
numpy_to_pil,
)
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
from ..utils.torch_utils import is_compiled_module
......@@ -725,6 +726,11 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
self.save_config(save_directory)
if push_to_hub:
# Create a new empty model card and eventually tag it
model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True)
model_card = populate_model_card(model_card)
model_card.save(os.path.join(save_directory, "README.md"))
self._upload_folder(
save_directory,
repo_id,
......
......@@ -28,7 +28,6 @@ from huggingface_hub import (
ModelCard,
ModelCardData,
create_repo,
get_full_repo_name,
hf_hub_download,
upload_folder,
)
......@@ -67,7 +66,6 @@ from .logging import get_logger
logger = get_logger(__name__)
MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "model_card_template.md"
SESSION_ID = uuid4().hex
......@@ -95,7 +93,20 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
return ua
def create_model_card(args, model_name):
def load_or_create_model_card(
repo_id_or_path: Optional[str] = None, token: Optional[str] = None, is_pipeline: bool = False
) -> ModelCard:
"""
Loads or creates a model card.
Args:
repo_id (`str`):
The repo_id where to look for the model card.
token (`str`, *optional*):
Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more details.
is_pipeline (`bool`, *optional*):
Boolean to indicate if we're adding tag to a [`DiffusionPipeline`].
"""
if not is_jinja_available():
raise ValueError(
"Modelcard rendering is based on Jinja templates."
......@@ -103,45 +114,24 @@ def create_model_card(args, model_name):
" To install it, please run `pip install Jinja2`."
)
if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
return
hub_token = args.hub_token if hasattr(args, "hub_token") else None
repo_name = get_full_repo_name(model_name, token=hub_token)
model_card = ModelCard.from_template(
card_data=ModelCardData( # Card metadata object that will be converted to YAML block
language="en",
license="apache-2.0",
library_name="diffusers",
tags=[],
datasets=args.dataset_name,
metrics=[],
),
template_path=MODEL_CARD_TEMPLATE_PATH,
model_name=model_name,
repo_name=repo_name,
dataset_name=args.dataset_name if hasattr(args, "dataset_name") else None,
learning_rate=args.learning_rate,
train_batch_size=args.train_batch_size,
eval_batch_size=args.eval_batch_size,
gradient_accumulation_steps=(
args.gradient_accumulation_steps if hasattr(args, "gradient_accumulation_steps") else None
),
adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None,
adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None,
adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None,
adam_epsilon=args.adam_epsilon if hasattr(args, "adam_epsilon") else None,
lr_scheduler=args.lr_scheduler if hasattr(args, "lr_scheduler") else None,
lr_warmup_steps=args.lr_warmup_steps if hasattr(args, "lr_warmup_steps") else None,
ema_inv_gamma=args.ema_inv_gamma if hasattr(args, "ema_inv_gamma") else None,
ema_power=args.ema_power if hasattr(args, "ema_power") else None,
ema_max_decay=args.ema_max_decay if hasattr(args, "ema_max_decay") else None,
mixed_precision=args.mixed_precision,
)
card_path = os.path.join(args.output_dir, "README.md")
model_card.save(card_path)
try:
# Check if the model card is present on the remote repo
model_card = ModelCard.load(repo_id_or_path, token=token)
except EntryNotFoundError:
# Otherwise create a simple model card from template
component = "pipeline" if is_pipeline else "model"
model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated."
card_data = ModelCardData()
model_card = ModelCard.from_template(card_data, model_description=model_description)
return model_card
def populate_model_card(model_card: ModelCard) -> ModelCard:
"""Populates the `model_card` with library name."""
if model_card.data.library_name is None:
model_card.data.library_name = "diffusers"
return model_card
def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str] = None):
......@@ -435,6 +425,10 @@ class PushToHubMixin:
"""
repo_id = create_repo(repo_id, private=private, token=token, exist_ok=True).repo_id
# Create a new empty model card and eventually tag it
model_card = load_or_create_model_card(repo_id, token=token)
model_card = populate_model_card(model_card)
# Save all files.
save_kwargs = {"safe_serialization": safe_serialization}
if "Scheduler" not in self.__class__.__name__:
......@@ -443,6 +437,9 @@ class PushToHubMixin:
with tempfile.TemporaryDirectory() as tmpdir:
self.save_pretrained(tmpdir, **save_kwargs)
# Update model card if needed:
model_card.save(os.path.join(tmpdir, "README.md"))
return self._upload_folder(
tmpdir,
repo_id,
......
---
{{ card_data }}
---
<!-- This model card has been generated automatically according to the information the training script had access to. You
should probably proofread and complete it, then remove this comment. -->
# {{ model_name | default("Diffusion Model") }}
## Model description
This diffusion model is trained with the [🤗 Diffusers](https://github.com/huggingface/diffusers) library
on the `{{ dataset_name }}` dataset.
## Intended uses & limitations
#### How to use
```python
# TODO: add an example code snippet for running this diffusion pipeline
```
#### Limitations and bias
[TODO: provide examples of latent issues and potential remediations]
## Training data
[TODO: describe the data used to train the model]
### Training hyperparameters
The following hyperparameters were used during training:
- learning_rate: {{ learning_rate }}
- train_batch_size: {{ train_batch_size }}
- eval_batch_size: {{ eval_batch_size }}
- gradient_accumulation_steps: {{ gradient_accumulation_steps }}
- optimizer: AdamW with betas=({{ adam_beta1 }}, {{ adam_beta2 }}), weight_decay={{ adam_weight_decay }} and epsilon={{ adam_epsilon }}
- lr_scheduler: {{ lr_scheduler }}
- lr_warmup_steps: {{ lr_warmup_steps }}
- ema_inv_gamma: {{ ema_inv_gamma }}
- ema_inv_gamma: {{ ema_power }}
- ema_inv_gamma: {{ ema_max_decay }}
- mixed_precision: {{ mixed_precision }}
### Training results
📈 [TensorBoard logs](https://huggingface.co/{{ repo_name }}/tensorboard?#scalars)
......@@ -24,7 +24,8 @@ from typing import Dict, List, Tuple
import numpy as np
import requests_mock
import torch
from huggingface_hub import delete_repo
from huggingface_hub import ModelCard, delete_repo
from huggingface_hub.utils import is_jinja_available
from requests.exceptions import HTTPError
from diffusers.models import UNet2DConditionModel
......@@ -732,3 +733,26 @@ class ModelPushToHubTester(unittest.TestCase):
# Reset repo
delete_repo(self.org_repo_id, token=TOKEN)
@unittest.skipIf(
not is_jinja_available(),
reason="Model card tests cannot be performed without Jinja installed.",
)
def test_push_to_hub_library_name(self):
model = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
)
model.push_to_hub(self.repo_id, token=TOKEN)
model_card = ModelCard.load(f"{USER}/{self.repo_id}", token=TOKEN).data
assert model_card.library_name == "diffusers"
# Reset repo
delete_repo(self.repo_id, token=TOKEN)
......@@ -15,37 +15,15 @@
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory
from unittest.mock import Mock, patch
import diffusers.utils.hub_utils
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
class CreateModelCardTest(unittest.TestCase):
@patch("diffusers.utils.hub_utils.get_full_repo_name")
def test_create_model_card(self, repo_name_mock: Mock) -> None:
repo_name_mock.return_value = "full_repo_name"
def test_generate_model_card_with_library_name(self):
with TemporaryDirectory() as tmpdir:
# Dummy args values
args = Mock()
args.output_dir = tmpdir
args.local_rank = 0
args.hub_token = "hub_token"
args.dataset_name = "dataset_name"
args.learning_rate = 0.01
args.train_batch_size = 100000
args.eval_batch_size = 10000
args.gradient_accumulation_steps = 0.01
args.adam_beta1 = 0.02
args.adam_beta2 = 0.03
args.adam_weight_decay = 0.0005
args.adam_epsilon = 0.000001
args.lr_scheduler = 1
args.lr_warmup_steps = 10
args.ema_inv_gamma = 0.001
args.ema_power = 0.1
args.ema_max_decay = 0.2
args.mixed_precision = True
# Model card mush be rendered and saved
diffusers.utils.hub_utils.create_model_card(args, model_name="model_name")
self.assertTrue((Path(tmpdir) / "README.md").is_file())
file_path = Path(tmpdir) / "README.md"
file_path.write_text("---\nlibrary_name: foo\n---\nContent\n")
model_card = load_or_create_model_card(file_path)
populate_model_card(model_card)
assert model_card.data.library_name == "foo"
......@@ -13,7 +13,8 @@ from typing import Callable, Union
import numpy as np
import PIL.Image
import torch
from huggingface_hub import delete_repo
from huggingface_hub import ModelCard, delete_repo
from huggingface_hub.utils import is_jinja_available
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
import diffusers
......@@ -1142,6 +1143,21 @@ class PipelinePushToHubTester(unittest.TestCase):
# Reset repo
delete_repo(self.org_repo_id, token=TOKEN)
@unittest.skipIf(
not is_jinja_available(),
reason="Model card tests cannot be performed without Jinja installed.",
)
def test_push_to_hub_library_name(self):
components = self.get_pipeline_components()
pipeline = StableDiffusionPipeline(**components)
pipeline.push_to_hub(self.repo_id, token=TOKEN)
model_card = ModelCard.load(f"{USER}/{self.repo_id}", token=TOKEN).data
assert model_card.library_name == "diffusers"
# Reset repo
delete_repo(self.repo_id, token=TOKEN)
# For SDXL and its derivative pipelines (such as ControlNet), we have the text encoders
# and the tokenizers as optional components. So, we need to override the `test_save_load_optional_components()`
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment