Unverified Commit 777063e1 authored by Bhavay Malhotra's avatar Bhavay Malhotra Committed by GitHub
Browse files

Update textual_inversion.py (#6952)



* Update textual_inversion.py

* Apply suggestions from code review

* Update textual_inversion.py

* Update textual_inversion.py

* Update textual_inversion.py

* Update textual_inversion.py

* Update examples/textual_inversion/textual_inversion.py
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* Update textual_inversion.py

* styling

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 104afbce
...@@ -53,6 +53,7 @@ from diffusers import ( ...@@ -53,6 +53,7 @@ from diffusers import (
) )
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
...@@ -84,32 +85,30 @@ check_min_version("0.27.0.dev0") ...@@ -84,32 +85,30 @@ check_min_version("0.27.0.dev0")
logger = get_logger(__name__) logger = get_logger(__name__)
def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None): def save_model_card(repo_id: str, images: list = None, base_model: str = None, repo_folder: str = None):
img_str = "" img_str = ""
if images is not None:
for i, image in enumerate(images): for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png")) image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n" img_str += f"![img_{i}](./image_{i}.png)\n"
model_description = f"""
yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
- textual_inversion
inference: true
---
"""
model_card = f"""
# Textual inversion text2image fine-tuning - {repo_id} # Textual inversion text2image fine-tuning - {repo_id}
These are textual inversion adaption weights for {base_model}. You can find some example images in the following. \n These are textual inversion adaption weights for {base_model}. You can find some example images in the following. \n
{img_str} {img_str}
""" """
with open(os.path.join(repo_folder, "README.md"), "w") as f: model_card = load_or_create_model_card(
f.write(yaml + model_card) repo_id_or_path=repo_id,
from_training=True,
license="creativeml-openrail-m",
base_model=base_model,
model_description=model_description,
inference=True,
)
tags = ["stable-diffusion", "stable-diffusion-diffusers", "text-to-image", "diffusers", "textual_inversion"]
model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))
def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch): def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment