Unverified Commit 8e69708b authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Examples/DreamBooth] refactor save_model_card utility in dreambooth examples (#3543)

refactor save_model_card utility in dreambooth examples.
parent db56f8a4
...@@ -46,6 +46,7 @@ from diffusers import ( ...@@ -46,6 +46,7 @@ from diffusers import (
DDPMScheduler, DDPMScheduler,
DiffusionPipeline, DiffusionPipeline,
DPMSolverMultistepScheduler, DPMSolverMultistepScheduler,
StableDiffusionPipeline,
UNet2DConditionModel, UNet2DConditionModel,
) )
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
...@@ -62,7 +63,15 @@ check_min_version("0.17.0.dev0") ...@@ -62,7 +63,15 @@ check_min_version("0.17.0.dev0")
logger = get_logger(__name__) logger = get_logger(__name__)
def save_model_card(repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None): def save_model_card(
repo_id: str,
images=None,
base_model=str,
train_text_encoder=False,
prompt=str,
repo_folder=None,
pipeline: DiffusionPipeline = None,
):
img_str = "" img_str = ""
for i, image in enumerate(images): for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png")) image.save(os.path.join(repo_folder, f"image_{i}.png"))
...@@ -74,8 +83,8 @@ license: creativeml-openrail-m ...@@ -74,8 +83,8 @@ license: creativeml-openrail-m
base_model: {base_model} base_model: {base_model}
instance_prompt: {prompt} instance_prompt: {prompt}
tags: tags:
- stable-diffusion - {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'}
- stable-diffusion-diffusers - {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'}
- text-to-image - text-to-image
- diffusers - diffusers
- dreambooth - dreambooth
...@@ -1297,6 +1306,7 @@ def main(args): ...@@ -1297,6 +1306,7 @@ def main(args):
train_text_encoder=args.train_text_encoder, train_text_encoder=args.train_text_encoder,
prompt=args.instance_prompt, prompt=args.instance_prompt,
repo_folder=args.output_dir, repo_folder=args.output_dir,
pipeline=pipeline,
) )
upload_folder( upload_folder(
repo_id=repo_id, repo_id=repo_id,
......
...@@ -68,7 +68,15 @@ check_min_version("0.17.0.dev0") ...@@ -68,7 +68,15 @@ check_min_version("0.17.0.dev0")
logger = get_logger(__name__) logger = get_logger(__name__)
def save_model_card(repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None): def save_model_card(
repo_id: str,
images=None,
base_model=str,
train_text_encoder=False,
prompt=str,
repo_folder=None,
pipeline: DiffusionPipeline = None,
):
img_str = "" img_str = ""
for i, image in enumerate(images): for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png")) image.save(os.path.join(repo_folder, f"image_{i}.png"))
...@@ -80,8 +88,8 @@ license: creativeml-openrail-m ...@@ -80,8 +88,8 @@ license: creativeml-openrail-m
base_model: {base_model} base_model: {base_model}
instance_prompt: {prompt} instance_prompt: {prompt}
tags: tags:
- stable-diffusion - {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'}
- stable-diffusion-diffusers - {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'}
- text-to-image - text-to-image
- diffusers - diffusers
- lora - lora
...@@ -844,7 +852,7 @@ def main(args): ...@@ -844,7 +852,7 @@ def main(args):
hidden_size=module.out_features, cross_attention_dim=None hidden_size=module.out_features, cross_attention_dim=None
) )
text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs) text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
temp_pipeline = StableDiffusionPipeline.from_pretrained( temp_pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, text_encoder=text_encoder args.pretrained_model_name_or_path, text_encoder=text_encoder
) )
temp_pipeline._modify_text_encoder(text_lora_attn_procs) temp_pipeline._modify_text_encoder(text_lora_attn_procs)
...@@ -1332,6 +1340,7 @@ def main(args): ...@@ -1332,6 +1340,7 @@ def main(args):
train_text_encoder=args.train_text_encoder, train_text_encoder=args.train_text_encoder,
prompt=args.instance_prompt, prompt=args.instance_prompt,
repo_folder=args.output_dir, repo_folder=args.output_dir,
pipeline=pipeline,
) )
upload_folder( upload_folder(
repo_id=repo_id, repo_id=repo_id,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment