Unverified Commit 76696dca authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Model Card] standardize dreambooth model card (#6729)

* feat: standarize model card creation for dreambooth training.

* correct 'inference

* remove comments.

* take component out of kwargs

* style

* add: card template to have a leaner description.

* widget support.

* propagate changes to train_dreambooth_lora

* propagate changes to custom diffusion

* make widget properly type-annotated
parent 17612de4
...@@ -58,6 +58,7 @@ from diffusers.models.attention_processor import ( ...@@ -58,6 +58,7 @@ from diffusers.models.attention_processor import (
) )
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
...@@ -78,21 +79,7 @@ def save_model_card(repo_id: str, images=None, base_model=str, prompt=str, repo_ ...@@ -78,21 +79,7 @@ def save_model_card(repo_id: str, images=None, base_model=str, prompt=str, repo_
image.save(os.path.join(repo_folder, f"image_{i}.png")) image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n" img_str += f"![img_{i}](./image_{i}.png)\n"
yaml = f""" model_description = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
instance_prompt: {prompt}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
- custom-diffusion
inference: true
---
"""
model_card = f"""
# Custom Diffusion - {repo_id} # Custom Diffusion - {repo_id}
These are Custom Diffusion adaption weights for {base_model}. The weights were trained on {prompt} using [Custom Diffusion](https://www.cs.cmu.edu/~custom-diffusion). You can find some example images in the following. \n These are Custom Diffusion adaption weights for {base_model}. The weights were trained on {prompt} using [Custom Diffusion](https://www.cs.cmu.edu/~custom-diffusion). You can find some example images in the following. \n
...@@ -100,8 +87,20 @@ These are Custom Diffusion adaption weights for {base_model}. The weights were t ...@@ -100,8 +87,20 @@ These are Custom Diffusion adaption weights for {base_model}. The weights were t
\nFor more details on the training, please follow [this link](https://github.com/huggingface/diffusers/blob/main/examples/custom_diffusion). \nFor more details on the training, please follow [this link](https://github.com/huggingface/diffusers/blob/main/examples/custom_diffusion).
""" """
with open(os.path.join(repo_folder, "README.md"), "w") as f: model_card = load_or_create_model_card(
f.write(yaml + model_card) repo_id_or_path=repo_id,
from_training=True,
license="creativeml-openrail-m",
base_model=base_model,
instance_prompt=prompt,
model_description=model_description,
inference=True,
)
tags = ["text-to-image", "diffusers", "stable-diffusion", "stable-diffusion-diffusers", "custom-diffusion"]
model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
......
...@@ -54,6 +54,7 @@ from diffusers import ( ...@@ -54,6 +54,7 @@ from diffusers import (
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module from diffusers.utils.torch_utils import is_compiled_module
...@@ -69,33 +70,20 @@ logger = get_logger(__name__) ...@@ -69,33 +70,20 @@ logger = get_logger(__name__)
def save_model_card( def save_model_card(
repo_id: str, repo_id: str,
images=None, images: list = None,
base_model=str, base_model: str = None,
train_text_encoder=False, train_text_encoder=False,
prompt=str, prompt: str = None,
repo_folder=None, repo_folder: str = None,
pipeline: DiffusionPipeline = None, pipeline: DiffusionPipeline = None,
): ):
img_str = "" img_str = ""
for i, image in enumerate(images): if images is not None:
image.save(os.path.join(repo_folder, f"image_{i}.png")) for i, image in enumerate(images):
img_str += f"![img_{i}](./image_{i}.png)\n" image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n"
yaml = f"""
--- model_description = f"""
license: creativeml-openrail-m
base_model: {base_model}
instance_prompt: {prompt}
tags:
- {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'}
- {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'}
- text-to-image
- diffusers
- dreambooth
inference: true
---
"""
model_card = f"""
# DreamBooth - {repo_id} # DreamBooth - {repo_id}
This is a dreambooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). This is a dreambooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/).
...@@ -104,8 +92,24 @@ You can find some example images in the following. \n ...@@ -104,8 +92,24 @@ You can find some example images in the following. \n
DreamBooth for the text encoder was enabled: {train_text_encoder}. DreamBooth for the text encoder was enabled: {train_text_encoder}.
""" """
with open(os.path.join(repo_folder, "README.md"), "w") as f: model_card = load_or_create_model_card(
f.write(yaml + model_card) repo_id_or_path=repo_id,
from_training=True,
license="creativeml-openrail-m",
base_model=base_model,
instance_prompt=prompt,
model_description=model_description,
inference=True,
)
tags = ["text-to-image", "dreambooth"]
if isinstance(pipeline, StableDiffusionPipeline):
tags.extend(["stable-diffusion", "stable-diffusion-diffusers"])
else:
tags.extend(["if", "if-diffusers"])
model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))
def log_validation( def log_validation(
......
...@@ -61,6 +61,7 @@ from diffusers.utils import ( ...@@ -61,6 +61,7 @@ from diffusers.utils import (
convert_unet_state_dict_to_peft, convert_unet_state_dict_to_peft,
is_wandb_available, is_wandb_available,
) )
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module from diffusers.utils.torch_utils import is_compiled_module
...@@ -85,21 +86,7 @@ def save_model_card( ...@@ -85,21 +86,7 @@ def save_model_card(
image.save(os.path.join(repo_folder, f"image_{i}.png")) image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n" img_str += f"![img_{i}](./image_{i}.png)\n"
yaml = f""" model_description = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
instance_prompt: {prompt}
tags:
- {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'}
- {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'}
- text-to-image
- diffusers
- lora
inference: true
---
"""
model_card = f"""
# LoRA DreamBooth - {repo_id} # LoRA DreamBooth - {repo_id}
These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
...@@ -107,8 +94,23 @@ These are LoRA adaption weights for {base_model}. The weights were trained on {p ...@@ -107,8 +94,23 @@ These are LoRA adaption weights for {base_model}. The weights were trained on {p
LoRA for the text encoder was enabled: {train_text_encoder}. LoRA for the text encoder was enabled: {train_text_encoder}.
""" """
with open(os.path.join(repo_folder, "README.md"), "w") as f: model_card = load_or_create_model_card(
f.write(yaml + model_card) repo_id_or_path=repo_id,
from_training=True,
license="creativeml-openrail-m",
base_model=base_model,
instance_prompt=prompt,
model_description=model_description,
inference=True,
)
tags = ["text-to-image", "diffusers", "lora"]
if isinstance(pipeline, StableDiffusionPipeline):
tags.extend(["stable-diffusion", "stable-diffusion-diffusers"])
else:
tags.extend(["if", "if-diffusers"])
model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
......
...@@ -62,6 +62,7 @@ from diffusers.utils import ( ...@@ -62,6 +62,7 @@ from diffusers.utils import (
convert_unet_state_dict_to_peft, convert_unet_state_dict_to_peft,
is_wandb_available, is_wandb_available,
) )
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module from diffusers.utils.torch_utils import is_compiled_module
...@@ -75,40 +76,22 @@ logger = get_logger(__name__) ...@@ -75,40 +76,22 @@ logger = get_logger(__name__)
def save_model_card( def save_model_card(
repo_id: str, repo_id: str,
images=None, images=None,
base_model=str, base_model: str = None,
train_text_encoder=False, train_text_encoder=False,
instance_prompt=str, instance_prompt=None,
validation_prompt=str, validation_prompt=None,
repo_folder=None, repo_folder=None,
vae_path=None, vae_path=None,
): ):
img_str = "widget:\n" if images else "" widget_dict = []
for i, image in enumerate(images): if images is not None:
image.save(os.path.join(repo_folder, f"image_{i}.png")) for i, image in enumerate(images):
img_str += f""" image.save(os.path.join(repo_folder, f"image_{i}.png"))
- text: '{validation_prompt if validation_prompt else ' ' }' widget_dict.append(
output: {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
url: )
"image_{i}.png"
"""
yaml = f"""
---
tags:
- stable-diffusion-xl
- stable-diffusion-xl-diffusers
- text-to-image
- diffusers
- lora
- template:sd-lora
{img_str}
base_model: {base_model}
instance_prompt: {instance_prompt}
license: openrail++
---
"""
model_card = f""" model_description = f"""
# SDXL LoRA DreamBooth - {repo_id} # SDXL LoRA DreamBooth - {repo_id}
<Gallery /> <Gallery />
...@@ -134,8 +117,27 @@ Weights for this model are available in Safetensors format. ...@@ -134,8 +117,27 @@ Weights for this model are available in Safetensors format.
[Download]({repo_id}/tree/main) them in the Files & versions tab. [Download]({repo_id}/tree/main) them in the Files & versions tab.
""" """
with open(os.path.join(repo_folder, "README.md"), "w") as f: model_card = load_or_create_model_card(
f.write(yaml + model_card) repo_id_or_path=repo_id,
from_training=True,
license="openrail++",
base_model=base_model,
instance_prompt=instance_prompt,
model_description=model_description,
widget=widget_dict,
)
tags = [
"text-to-image",
"stable-diffusion-xl",
"stable-diffusion-xl-diffusers",
"text-to-image",
"diffusers",
"lora",
"template:sd-lora",
]
model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))
def import_model_class_from_model_name_or_path( def import_model_class_from_model_name_or_path(
......
...@@ -21,7 +21,7 @@ import tempfile ...@@ -21,7 +21,7 @@ import tempfile
import traceback import traceback
import warnings import warnings
from pathlib import Path from pathlib import Path
from typing import Dict, Optional, Union from typing import Dict, List, Optional, Union
from uuid import uuid4 from uuid import uuid4
from huggingface_hub import ( from huggingface_hub import (
...@@ -65,7 +65,7 @@ from .logging import get_logger ...@@ -65,7 +65,7 @@ from .logging import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "model_card_template.md"
SESSION_ID = uuid4().hex SESSION_ID = uuid4().hex
...@@ -94,43 +94,87 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str: ...@@ -94,43 +94,87 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
def load_or_create_model_card( def load_or_create_model_card(
repo_id_or_path: Optional[str] = None, token: Optional[str] = None, is_pipeline: bool = False repo_id_or_path: str = None,
token: Optional[str] = None,
is_pipeline: bool = False,
from_training: bool = False,
model_description: Optional[str] = None,
base_model: str = None,
prompt: Optional[str] = None,
license: Optional[str] = None,
widget: Optional[List[dict]] = None,
inference: Optional[bool] = None,
) -> ModelCard: ) -> ModelCard:
""" """
Loads or creates a model card. Loads or creates a model card.
Args: Args:
repo_id (`str`): repo_id_or_path (`str`):
The repo_id where to look for the model card. The repo id (e.g., "runwayml/stable-diffusion-v1-5") or local path where to look for the model card.
token (`str`, *optional*): token (`str`, *optional*):
Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more details. Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more details.
is_pipeline (`bool`, *optional*): is_pipeline (`bool`):
Boolean to indicate if we're adding tag to a [`DiffusionPipeline`]. Boolean to indicate if we're adding tag to a [`DiffusionPipeline`].
from_training: (`bool`): Boolean flag to denote if the model card is being created from a training script.
model_description (`str`, *optional*): Model description to add to the model card. Helpful when using
`load_or_create_model_card` from a training script.
base_model (`str`): Base model identifier (e.g., "stabilityai/stable-diffusion-xl-base-1.0"). Useful
for DreamBooth-like training.
prompt (`str`, *optional*): Prompt used for training. Useful for DreamBooth-like training.
license: (`str`, *optional*): License of the output artifact. Helpful when using
`load_or_create_model_card` from a training script.
widget (`List[dict]`, *optional*): Widget to accompany a gallery template.
inference: (`bool`, optional): Whether to turn on inference widget. Helpful when using
`load_or_create_model_card` from a training script.
""" """
if not is_jinja_available(): if not is_jinja_available():
raise ValueError( raise ValueError(
"Modelcard rendering is based on Jinja templates." "Modelcard rendering is based on Jinja templates."
" Please make sure to have `jinja` installed before using `create_model_card`." " Please make sure to have `jinja` installed before using `load_or_create_model_card`."
" To install it, please run `pip install Jinja2`." " To install it, please run `pip install Jinja2`."
) )
try: try:
# Check if the model card is present on the remote repo # Check if the model card is present on the remote repo
model_card = ModelCard.load(repo_id_or_path, token=token) model_card = ModelCard.load(repo_id_or_path, token=token)
except EntryNotFoundError: except (EntryNotFoundError, RepositoryNotFoundError):
# Otherwise create a simple model card from template # Otherwise create a model card from template
component = "pipeline" if is_pipeline else "model" if from_training:
model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated." model_card = ModelCard.from_template(
card_data = ModelCardData() card_data=ModelCardData( # Card metadata object that will be converted to YAML block
model_card = ModelCard.from_template(card_data, model_description=model_description) license=license,
library_name="diffusers",
inference=inference,
base_model=base_model,
instance_prompt=prompt,
widget=widget,
),
template_path=MODEL_CARD_TEMPLATE_PATH,
model_description=model_description,
)
else:
card_data = ModelCardData()
component = "pipeline" if is_pipeline else "model"
if model_description is None:
model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated."
model_card = ModelCard.from_template(card_data, model_description=model_description)
return model_card return model_card
def populate_model_card(model_card: ModelCard) -> ModelCard: def populate_model_card(model_card: ModelCard, tags: Union[str, List[str]] = None) -> ModelCard:
"""Populates the `model_card` with library name.""" """Populates the `model_card` with library name and optional tags."""
if model_card.data.library_name is None: if model_card.data.library_name is None:
model_card.data.library_name = "diffusers" model_card.data.library_name = "diffusers"
if tags is not None:
if isinstance(tags, str):
tags = [tags]
if model_card.data.tags is None:
model_card.data.tags = []
for tag in tags:
model_card.data.tags.append(tag)
return model_card return model_card
......
---
{{ card_data }}
---
<!-- This model card has been generated automatically according to the information the training script had access to. You
should probably proofread and complete it, then remove this comment. -->
{{ model_description }}
## Intended uses & limitations
#### How to use
```python
# TODO: add an example code snippet for running this diffusion pipeline
```
#### Limitations and bias
[TODO: provide examples of latent issues and potential remediations]
## Training details
[TODO: describe the data used to train the model]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment