Unverified Commit 76696dca authored by Sayak Paul, committed by GitHub

[Model Card] standardize dreambooth model card (#6729)

* feat: standardize model card creation for dreambooth training.

* correct 'inference'

* remove comments.

* take component out of kwargs

* style

* add: card template to have a leaner description.

* widget support.

* propagate changes to train_dreambooth_lora

* propagate changes to custom diffusion

* make widget properly type-annotated
parent 17612de4
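At a glance, the change replaces each training script's hand-written YAML header with a shared helper flow: build a plain `model_description` string, then let `load_or_create_model_card` and `populate_model_card` handle the metadata. Below is a minimal sketch of that flow; the repo id, prompt, tags, and output folder are illustrative placeholders, not values from this diff, and the helper signatures are the ones added in `hub_utils.py` further down.

```python
import os

from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card

# Hypothetical values; a real training script passes its own repo id, prompt, and folder.
repo_id = "my-user/my-dreambooth-model"
base_model = "runwayml/stable-diffusion-v1-5"
instance_prompt = "a photo of sks dog"
repo_folder = "./output"
os.makedirs(repo_folder, exist_ok=True)

model_description = f"""
# DreamBooth - {repo_id}

This is a DreamBooth model derived from {base_model}, trained on {instance_prompt}.
"""

# Create (or load) the card; the YAML metadata block is handled by the helper.
model_card = load_or_create_model_card(
    repo_id_or_path=repo_id,
    from_training=True,
    license="creativeml-openrail-m",
    base_model=base_model,
    instance_prompt=instance_prompt,
    model_description=model_description,
    inference=True,
)
# Attach the library name and tags, then write README.md next to the weights.
model_card = populate_model_card(model_card, tags=["text-to-image", "diffusers", "dreambooth"])
model_card.save(os.path.join(repo_folder, "README.md"))
```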
@@ -58,6 +58,7 @@ from diffusers.models.attention_processor import (
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
@@ -78,21 +79,7 @@ def save_model_card(repo_id: str, images=None, base_model=str, prompt=str, repo_
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n"
yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
instance_prompt: {prompt}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
- custom-diffusion
inference: true
---
"""
model_card = f"""
model_description = f"""
# Custom Diffusion - {repo_id}
These are Custom Diffusion adaption weights for {base_model}. The weights were trained on {prompt} using [Custom Diffusion](https://www.cs.cmu.edu/~custom-diffusion). You can find some example images in the following. \n
@@ -100,8 +87,20 @@ These are Custom Diffusion adaption weights for {base_model}. The weights were t
\nFor more details on the training, please follow [this link](https://github.com/huggingface/diffusers/blob/main/examples/custom_diffusion).
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="creativeml-openrail-m",
base_model=base_model,
instance_prompt=prompt,
model_description=model_description,
inference=True,
)
tags = ["text-to-image", "diffusers", "stable-diffusion", "stable-diffusion-diffusers", "custom-diffusion"]
model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
......
@@ -54,6 +54,7 @@ from diffusers import (
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -69,33 +70,20 @@ logger = get_logger(__name__)
def save_model_card(
repo_id: str,
images=None,
base_model=str,
images: list = None,
base_model: str = None,
train_text_encoder=False,
prompt=str,
repo_folder=None,
prompt: str = None,
repo_folder: str = None,
pipeline: DiffusionPipeline = None,
):
img_str = ""
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n"
yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
instance_prompt: {prompt}
tags:
- {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'}
- {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'}
- text-to-image
- diffusers
- dreambooth
inference: true
---
"""
model_card = f"""
if images is not None:
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n"
model_description = f"""
# DreamBooth - {repo_id}
This is a dreambooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/).
@@ -104,8 +92,24 @@ You can find some example images in the following. \n
DreamBooth for the text encoder was enabled: {train_text_encoder}.
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="creativeml-openrail-m",
base_model=base_model,
instance_prompt=prompt,
model_description=model_description,
inference=True,
)
tags = ["text-to-image", "dreambooth"]
if isinstance(pipeline, StableDiffusionPipeline):
tags.extend(["stable-diffusion", "stable-diffusion-diffusers"])
else:
tags.extend(["if", "if-diffusers"])
model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))
def log_validation(
......
@@ -61,6 +61,7 @@ from diffusers.utils import (
convert_unet_state_dict_to_peft,
is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -85,21 +86,7 @@ def save_model_card(
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n"
yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
instance_prompt: {prompt}
tags:
- {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'}
- {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'}
- text-to-image
- diffusers
- lora
inference: true
---
"""
model_card = f"""
model_description = f"""
# LoRA DreamBooth - {repo_id}
These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
@@ -107,8 +94,23 @@ These are LoRA adaption weights for {base_model}. The weights were trained on {p
LoRA for the text encoder was enabled: {train_text_encoder}.
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="creativeml-openrail-m",
base_model=base_model,
instance_prompt=prompt,
model_description=model_description,
inference=True,
)
tags = ["text-to-image", "diffusers", "lora"]
if isinstance(pipeline, StableDiffusionPipeline):
tags.extend(["stable-diffusion", "stable-diffusion-diffusers"])
else:
tags.extend(["if", "if-diffusers"])
model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
......
@@ -62,6 +62,7 @@ from diffusers.utils import (
convert_unet_state_dict_to_peft,
is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -75,40 +76,22 @@ logger = get_logger(__name__)
def save_model_card(
repo_id: str,
images=None,
base_model=str,
base_model: str = None,
train_text_encoder=False,
instance_prompt=str,
validation_prompt=str,
instance_prompt=None,
validation_prompt=None,
repo_folder=None,
vae_path=None,
):
img_str = "widget:\n" if images else ""
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"""
- text: '{validation_prompt if validation_prompt else ' ' }'
output:
url:
"image_{i}.png"
"""
yaml = f"""
---
tags:
- stable-diffusion-xl
- stable-diffusion-xl-diffusers
- text-to-image
- diffusers
- lora
- template:sd-lora
{img_str}
base_model: {base_model}
instance_prompt: {instance_prompt}
license: openrail++
---
"""
widget_dict = []
if images is not None:
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
widget_dict.append(
{"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
)
model_card = f"""
model_description = f"""
# SDXL LoRA DreamBooth - {repo_id}
<Gallery />
@@ -134,8 +117,27 @@ Weights for this model are available in Safetensors format.
[Download]({repo_id}/tree/main) them in the Files & versions tab.
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="openrail++",
base_model=base_model,
instance_prompt=instance_prompt,
model_description=model_description,
widget=widget_dict,
)
tags = [
"text-to-image",
"stable-diffusion-xl",
"stable-diffusion-xl-diffusers",
"text-to-image",
"diffusers",
"lora",
"template:sd-lora",
]
model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))
def import_model_class_from_model_name_or_path(
......
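For the SDXL LoRA script, the validation images are no longer spliced into a hand-written YAML string; they are collected in a `widget` list and serialized into the card's front matter. A rough sketch of what that produces, assuming `huggingface_hub`'s `CardData.to_yaml()`; the repo id, prompt, and image file name are made up for illustration.

```python
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card

# Hypothetical widget entry, mirroring the dict built in the SDXL training loop above.
widget_dict = [{"text": "a photo of sks dog in a bucket", "output": {"url": "image_0.png"}}]

card = load_or_create_model_card(
    repo_id_or_path="my-user/my-sdxl-lora",  # placeholder repo id
    from_training=True,
    license="openrail++",
    base_model="stabilityai/stable-diffusion-xl-base-1.0",
    instance_prompt="a photo of sks dog",
    model_description="# SDXL LoRA DreamBooth\n\n<Gallery />",
    widget=widget_dict,
)
card = populate_model_card(card, tags=["text-to-image", "stable-diffusion-xl", "lora", "template:sd-lora"])
# The widget entries appear under `widget:` in the YAML block that huggingface_hub
# serializes from the card data.
print(card.data.to_yaml())
```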
@@ -21,7 +21,7 @@ import tempfile
import traceback
import warnings
from pathlib import Path
from typing import Dict, Optional, Union
from typing import Dict, List, Optional, Union
from uuid import uuid4
from huggingface_hub import (
@@ -65,7 +65,7 @@ from .logging import get_logger
logger = get_logger(__name__)
MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "model_card_template.md"
SESSION_ID = uuid4().hex
@@ -94,43 +94,87 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
def load_or_create_model_card(
repo_id_or_path: Optional[str] = None, token: Optional[str] = None, is_pipeline: bool = False
repo_id_or_path: str = None,
token: Optional[str] = None,
is_pipeline: bool = False,
from_training: bool = False,
model_description: Optional[str] = None,
base_model: str = None,
prompt: Optional[str] = None,
license: Optional[str] = None,
widget: Optional[List[dict]] = None,
inference: Optional[bool] = None,
) -> ModelCard:
"""
Loads or creates a model card.
Args:
repo_id (`str`):
The repo_id where to look for the model card.
repo_id_or_path (`str`):
The repo id (e.g., "runwayml/stable-diffusion-v1-5") or local path where to look for the model card.
token (`str`, *optional*):
Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more details.
is_pipeline (`bool`, *optional*):
is_pipeline (`bool`):
Boolean to indicate if we're adding tag to a [`DiffusionPipeline`].
from_training: (`bool`): Boolean flag to denote if the model card is being created from a training script.
model_description (`str`, *optional*): Model description to add to the model card. Helpful when using
`load_or_create_model_card` from a training script.
base_model (`str`): Base model identifier (e.g., "stabilityai/stable-diffusion-xl-base-1.0"). Useful
for DreamBooth-like training.
prompt (`str`, *optional*): Prompt used for training. Useful for DreamBooth-like training.
license: (`str`, *optional*): License of the output artifact. Helpful when using
`load_or_create_model_card` from a training script.
widget (`List[dict]`, *optional*): Widget to accompany a gallery template.
inference: (`bool`, optional): Whether to turn on inference widget. Helpful when using
`load_or_create_model_card` from a training script.
"""
if not is_jinja_available():
raise ValueError(
"Modelcard rendering is based on Jinja templates."
" Please make sure to have `jinja` installed before using `create_model_card`."
" Please make sure to have `jinja` installed before using `load_or_create_model_card`."
" To install it, please run `pip install Jinja2`."
)
try:
# Check if the model card is present on the remote repo
model_card = ModelCard.load(repo_id_or_path, token=token)
except EntryNotFoundError:
# Otherwise create a simple model card from template
component = "pipeline" if is_pipeline else "model"
model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated."
card_data = ModelCardData()
model_card = ModelCard.from_template(card_data, model_description=model_description)
except (EntryNotFoundError, RepositoryNotFoundError):
# Otherwise create a model card from template
if from_training:
model_card = ModelCard.from_template(
card_data=ModelCardData( # Card metadata object that will be converted to YAML block
license=license,
library_name="diffusers",
inference=inference,
base_model=base_model,
instance_prompt=prompt,
widget=widget,
),
template_path=MODEL_CARD_TEMPLATE_PATH,
model_description=model_description,
)
else:
card_data = ModelCardData()
component = "pipeline" if is_pipeline else "model"
if model_description is None:
model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated."
model_card = ModelCard.from_template(card_data, model_description=model_description)
return model_card
def populate_model_card(model_card: ModelCard) -> ModelCard:
"""Populates the `model_card` with library name."""
def populate_model_card(model_card: ModelCard, tags: Union[str, List[str]] = None) -> ModelCard:
"""Populates the `model_card` with library name and optional tags."""
if model_card.data.library_name is None:
model_card.data.library_name = "diffusers"
if tags is not None:
if isinstance(tags, str):
tags = [tags]
if model_card.data.tags is None:
model_card.data.tags = []
for tag in tags:
model_card.data.tags.append(tag)
return model_card
......
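`populate_model_card` itself stays small: it fills in `library_name` and now also accepts either a single tag or a list, appending to whatever tags the card already carries. A quick sketch of that behavior, assuming `huggingface_hub`'s `ModelCard`/`ModelCardData` as used in the diff; the demo card content is a placeholder.

```python
from huggingface_hub import ModelCard, ModelCardData

from diffusers.utils.hub_utils import populate_model_card

# Start from a bare card rendered with the default huggingface_hub template, purely for illustration.
card = ModelCard.from_template(ModelCardData(), model_description="Demo card.")
card = populate_model_card(card, tags="dreambooth")                    # a plain string is wrapped into a list
card = populate_model_card(card, tags=["text-to-image", "diffusers"])  # appended to the existing tags

print(card.data.library_name)  # -> "diffusers"
print(card.data.tags)          # -> ["dreambooth", "text-to-image", "diffusers"]
```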
---
{{ card_data }}
---
<!-- This model card has been generated automatically according to the information the training script had access to. You
should probably proofread and complete it, then remove this comment. -->
{{ model_description }}
## Intended uses & limitations
#### How to use
```python
# TODO: add an example code snippet for running this diffusion pipeline
```
#### Limitations and bias
[TODO: provide examples of latent issues and potential remediations]
## Training details
[TODO: describe the data used to train the model]
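The new template above keeps the generated README lean: `{{ card_data }}` becomes the YAML metadata block and `{{ model_description }}` is the per-script prose. A small sketch of rendering it directly; the local template path and the description text are assumptions for illustration, while in the library the template is resolved via `MODEL_CARD_TEMPLATE_PATH` next to `hub_utils.py`.

```python
from huggingface_hub import ModelCard, ModelCardData

# Assumes a local copy of the template shown above saved as model_card_template.md.
card = ModelCard.from_template(
    card_data=ModelCardData(license="openrail++", library_name="diffusers", tags=["lora"]),
    template_path="model_card_template.md",
    model_description="These are example LoRA adaption weights.",
)
print(card.content)  # YAML front matter followed by the rendered sections
```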