Unverified Commit d9023a67 authored by Abhipsha Das's avatar Abhipsha Das Committed by GitHub
Browse files

[Model Card] standardize advanced diffusion training sdxl lora (#7615)



* model card gen code

* push modelcard creation

* remove optional from params

* add import

* add use_dora check

* correct lora var use in tags

* make style && make quality

---------
Co-authored-by: default avatarAryan <aryan@huggingface.co>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent c4646a39
...@@ -71,6 +71,7 @@ from diffusers.utils import ( ...@@ -71,6 +71,7 @@ from diffusers.utils import (
convert_unet_state_dict_to_peft, convert_unet_state_dict_to_peft,
is_wandb_available, is_wandb_available,
) )
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module from diffusers.utils.torch_utils import is_compiled_module
...@@ -101,7 +102,7 @@ def determine_scheduler_type(pretrained_model_name_or_path, revision): ...@@ -101,7 +102,7 @@ def determine_scheduler_type(pretrained_model_name_or_path, revision):
def save_model_card( def save_model_card(
repo_id: str, repo_id: str,
use_dora: bool, use_dora: bool,
images=None, images: list = None,
base_model: str = None, base_model: str = None,
train_text_encoder=False, train_text_encoder=False,
train_text_encoder_ti=False, train_text_encoder_ti=False,
...@@ -111,20 +112,17 @@ def save_model_card( ...@@ -111,20 +112,17 @@ def save_model_card(
repo_folder=None, repo_folder=None,
vae_path=None, vae_path=None,
): ):
img_str = "widget:\n"
lora = "lora" if not use_dora else "dora" lora = "lora" if not use_dora else "dora"
widget_dict = []
if images is not None:
for i, image in enumerate(images): for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png")) image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f""" widget_dict.append(
- text: '{validation_prompt if validation_prompt else ' ' }' {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
output: )
url: else:
"image_{i}.png" widget_dict.append({"text": instance_prompt})
"""
if not images:
img_str += f"""
- text: '{instance_prompt}'
"""
embeddings_filename = f"{repo_folder}_emb" embeddings_filename = f"{repo_folder}_emb"
instance_prompt_webui = re.sub(r"<s\d+>", "", re.sub(r"<s\d+>", embeddings_filename, instance_prompt, count=1)) instance_prompt_webui = re.sub(r"<s\d+>", "", re.sub(r"<s\d+>", embeddings_filename, instance_prompt, count=1))
ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"<s\d+>", instance_prompt)) ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"<s\d+>", instance_prompt))
...@@ -169,23 +167,7 @@ pipeline.load_textual_inversion(state_dict["clip_g"], token=[{ti_keys}], text_en ...@@ -169,23 +167,7 @@ pipeline.load_textual_inversion(state_dict["clip_g"], token=[{ti_keys}], text_en
to trigger concept `{key}` → use `{tokens}` in your prompt \n to trigger concept `{key}` → use `{tokens}` in your prompt \n
""" """
yaml = f"""--- model_description = f"""
tags:
- stable-diffusion-xl
- stable-diffusion-xl-diffusers
- diffusers-training
- text-to-image
- diffusers
- {lora}
- template:sd-lora
{img_str}
base_model: {base_model}
instance_prompt: {instance_prompt}
license: openrail++
---
"""
model_card = f"""
# SDXL LoRA DreamBooth - {repo_id} # SDXL LoRA DreamBooth - {repo_id}
<Gallery /> <Gallery />
...@@ -234,8 +216,25 @@ Special VAE used for training: {vae_path}. ...@@ -234,8 +216,25 @@ Special VAE used for training: {vae_path}.
{license} {license}
""" """
with open(os.path.join(repo_folder, "README.md"), "w") as f: model_card = load_or_create_model_card(
f.write(yaml + model_card) repo_id_or_path=repo_id,
from_training=True,
license="openrail++",
base_model=base_model,
prompt=instance_prompt,
model_description=model_description,
widget=widget_dict,
)
tags = [
"text-to-image",
"stable-diffusion-xl",
"stable-diffusion-xl-diffusers",
"text-to-image",
"diffusers",
lora,
"template:sd-lora",
]
model_card = populate_model_card(model_card, tags=tags)
def log_validation( def log_validation(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment