"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "6133d98ff70eafad7b9f65da50a450a965d1957f"
Unverified Commit c7543184 authored by Abhipsha Das's avatar Abhipsha Das Committed by GitHub
Browse files

[Model Card] standardize advanced diffusion training sd15 lora (#7613)



* modelcard generation edit

* add missed tag

* fix param name

* fix var

* change str to dict

* add use_dora check

* use correct tags for lora

* make style && make quality

---------
Co-authored-by: default avatarAryan <aryan@huggingface.co>
parent d2e5cb3c
...@@ -67,6 +67,7 @@ from diffusers.utils import ( ...@@ -67,6 +67,7 @@ from diffusers.utils import (
convert_state_dict_to_kohya, convert_state_dict_to_kohya,
is_wandb_available, is_wandb_available,
) )
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
...@@ -79,30 +80,27 @@ logger = get_logger(__name__) ...@@ -79,30 +80,27 @@ logger = get_logger(__name__)
def save_model_card( def save_model_card(
repo_id: str, repo_id: str,
use_dora: bool, use_dora: bool,
images=None, images: list = None,
base_model=str, base_model: str = None,
train_text_encoder=False, train_text_encoder=False,
train_text_encoder_ti=False, train_text_encoder_ti=False,
token_abstraction_dict=None, token_abstraction_dict=None,
instance_prompt=str, instance_prompt=None,
validation_prompt=str, validation_prompt=None,
repo_folder=None, repo_folder=None,
vae_path=None, vae_path=None,
): ):
img_str = "widget:\n"
lora = "lora" if not use_dora else "dora" lora = "lora" if not use_dora else "dora"
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png")) widget_dict = []
img_str += f""" if images is not None:
- text: '{validation_prompt if validation_prompt else ' ' }' for i, image in enumerate(images):
output: image.save(os.path.join(repo_folder, f"image_{i}.png"))
url: widget_dict.append(
"image_{i}.png" {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
""" )
if not images: else:
img_str += f""" widget_dict.append({"text": instance_prompt})
- text: '{instance_prompt}'
"""
embeddings_filename = f"{repo_folder}_emb" embeddings_filename = f"{repo_folder}_emb"
instance_prompt_webui = re.sub(r"<s\d+>", "", re.sub(r"<s\d+>", embeddings_filename, instance_prompt, count=1)) instance_prompt_webui = re.sub(r"<s\d+>", "", re.sub(r"<s\d+>", embeddings_filename, instance_prompt, count=1))
ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"<s\d+>", instance_prompt)) ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"<s\d+>", instance_prompt))
...@@ -137,24 +135,7 @@ pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_en ...@@ -137,24 +135,7 @@ pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_en
trigger_str += f""" trigger_str += f"""
to trigger concept `{key}` → use `{tokens}` in your prompt \n to trigger concept `{key}` → use `{tokens}` in your prompt \n
""" """
model_description = f"""
yaml = f"""---
tags:
- stable-diffusion
- stable-diffusion-diffusers
- diffusers-training
- text-to-image
- diffusers
- {lora}
- template:sd-lora
{img_str}
base_model: {base_model}
instance_prompt: {instance_prompt}
license: openrail++
---
"""
model_card = f"""
# SD1.5 LoRA DreamBooth - {repo_id} # SD1.5 LoRA DreamBooth - {repo_id}
<Gallery /> <Gallery />
...@@ -202,8 +183,28 @@ Pivotal tuning was enabled: {train_text_encoder_ti}. ...@@ -202,8 +183,28 @@ Pivotal tuning was enabled: {train_text_encoder_ti}.
Special VAE used for training: {vae_path}. Special VAE used for training: {vae_path}.
""" """
with open(os.path.join(repo_folder, "README.md"), "w") as f: model_card = load_or_create_model_card(
f.write(yaml + model_card) repo_id_or_path=repo_id,
from_training=True,
license="openrail++",
base_model=base_model,
prompt=instance_prompt,
model_description=model_description,
inference=True,
widget=widget_dict,
)
tags = [
"text-to-image",
"diffusers",
"diffusers-training",
lora,
"template:sd-lora" "stable-diffusion",
"stable-diffusion-diffusers",
]
model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))
def import_model_class_from_model_name_or_path( def import_model_class_from_model_name_or_path(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment