Unverified Commit d4642144 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Let's make sure that dreambooth always uploads to the Hub (#3272)

* Update Dreambooth README

* Adapt all docs as well

* automatically write model card

* fix

* make style
parent 62906682
...@@ -98,7 +98,8 @@ accelerate launch train_dreambooth.py \ ...@@ -98,7 +98,8 @@ accelerate launch train_dreambooth.py \
--learning_rate=5e-6 \ --learning_rate=5e-6 \
--lr_scheduler="constant" \ --lr_scheduler="constant" \
--lr_warmup_steps=0 \ --lr_warmup_steps=0 \
--max_train_steps=400 --max_train_steps=400 \
--push_to_hub
``` ```
</pt> </pt>
<jax> <jax>
...@@ -161,7 +162,8 @@ accelerate launch train_dreambooth.py \ ...@@ -161,7 +162,8 @@ accelerate launch train_dreambooth.py \
--lr_scheduler="constant" \ --lr_scheduler="constant" \
--lr_warmup_steps=0 \ --lr_warmup_steps=0 \
--num_class_images=200 \ --num_class_images=200 \
--max_train_steps=800 --max_train_steps=800 \
--push_to_hub
``` ```
</pt> </pt>
<jax> <jax>
...@@ -225,7 +227,8 @@ accelerate launch train_dreambooth.py \ ...@@ -225,7 +227,8 @@ accelerate launch train_dreambooth.py \
--lr_scheduler="constant" \ --lr_scheduler="constant" \
--lr_warmup_steps=0 \ --lr_warmup_steps=0 \
--num_class_images=200 \ --num_class_images=200 \
--max_train_steps=800 --max_train_steps=800 \
--push_to_hub
``` ```
</pt> </pt>
<jax> <jax>
...@@ -387,7 +390,8 @@ accelerate launch train_dreambooth.py \ ...@@ -387,7 +390,8 @@ accelerate launch train_dreambooth.py \
--lr_scheduler="constant" \ --lr_scheduler="constant" \
--lr_warmup_steps=0 \ --lr_warmup_steps=0 \
--num_class_images=200 \ --num_class_images=200 \
--max_train_steps=800 --max_train_steps=800 \
--push_to_hub
``` ```
### 12GB GPU ### 12GB GPU
...@@ -418,7 +422,8 @@ accelerate launch train_dreambooth.py \ ...@@ -418,7 +422,8 @@ accelerate launch train_dreambooth.py \
--lr_scheduler="constant" \ --lr_scheduler="constant" \
--lr_warmup_steps=0 \ --lr_warmup_steps=0 \
--num_class_images=200 \ --num_class_images=200 \
--max_train_steps=800 --max_train_steps=800 \
--push_to_hub
``` ```
### 8 GB GPU ### 8 GB GPU
...@@ -464,7 +469,8 @@ accelerate launch train_dreambooth.py \ ...@@ -464,7 +469,8 @@ accelerate launch train_dreambooth.py \
--lr_warmup_steps=0 \ --lr_warmup_steps=0 \
--num_class_images=200 \ --num_class_images=200 \
--max_train_steps=800 \ --max_train_steps=800 \
--mixed_precision=fp16 --mixed_precision=fp16 \
--push_to_hub
``` ```
## Inference ## Inference
......
...@@ -80,7 +80,8 @@ accelerate launch train_dreambooth.py \ ...@@ -80,7 +80,8 @@ accelerate launch train_dreambooth.py \
--learning_rate=5e-6 \ --learning_rate=5e-6 \
--lr_scheduler="constant" \ --lr_scheduler="constant" \
--lr_warmup_steps=0 \ --lr_warmup_steps=0 \
--max_train_steps=400 --max_train_steps=400 \
--push_to_hub
``` ```
### Training with prior-preservation loss ### Training with prior-preservation loss
...@@ -109,7 +110,8 @@ accelerate launch train_dreambooth.py \ ...@@ -109,7 +110,8 @@ accelerate launch train_dreambooth.py \
--lr_scheduler="constant" \ --lr_scheduler="constant" \
--lr_warmup_steps=0 \ --lr_warmup_steps=0 \
--num_class_images=200 \ --num_class_images=200 \
--max_train_steps=800 --max_train_steps=800 \
--push_to_hub
``` ```
...@@ -141,7 +143,8 @@ accelerate launch train_dreambooth.py \ ...@@ -141,7 +143,8 @@ accelerate launch train_dreambooth.py \
--lr_scheduler="constant" \ --lr_scheduler="constant" \
--lr_warmup_steps=0 \ --lr_warmup_steps=0 \
--num_class_images=200 \ --num_class_images=200 \
--max_train_steps=800 --max_train_steps=800 \
--push_to_hub
``` ```
...@@ -176,7 +179,8 @@ accelerate launch train_dreambooth.py \ ...@@ -176,7 +179,8 @@ accelerate launch train_dreambooth.py \
--lr_scheduler="constant" \ --lr_scheduler="constant" \
--lr_warmup_steps=0 \ --lr_warmup_steps=0 \
--num_class_images=200 \ --num_class_images=200 \
--max_train_steps=800 --max_train_steps=800 \
--push_to_hub
``` ```
...@@ -218,7 +222,8 @@ accelerate launch --mixed_precision="fp16" train_dreambooth.py \ ...@@ -218,7 +222,8 @@ accelerate launch --mixed_precision="fp16" train_dreambooth.py \
--lr_scheduler="constant" \ --lr_scheduler="constant" \
--lr_warmup_steps=0 \ --lr_warmup_steps=0 \
--num_class_images=200 \ --num_class_images=200 \
--max_train_steps=800 --max_train_steps=800 \
--push_to_hub
``` ```
### Fine-tune text encoder with the UNet. ### Fine-tune text encoder with the UNet.
...@@ -251,7 +256,8 @@ accelerate launch train_dreambooth.py \ ...@@ -251,7 +256,8 @@ accelerate launch train_dreambooth.py \
--lr_scheduler="constant" \ --lr_scheduler="constant" \
--lr_warmup_steps=0 \ --lr_warmup_steps=0 \
--num_class_images=200 \ --num_class_images=200 \
--max_train_steps=800 --max_train_steps=800 \
--push_to_hub
``` ```
### Using DreamBooth for pipelines other than Stable Diffusion ### Using DreamBooth for pipelines other than Stable Diffusion
......
...@@ -61,6 +61,39 @@ check_min_version("0.17.0.dev0") ...@@ -61,6 +61,39 @@ check_min_version("0.17.0.dev0")
logger = get_logger(__name__) logger = get_logger(__name__)
def save_model_card(repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None):
img_str = ""
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n"
yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
instance_prompt: {prompt}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
- dreambooth
inference: true
---
"""
model_card = f"""
# DreamBooth - {repo_id}
This is a dreambooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/).
You can find some example images in the following. \n
{img_str}
DreamBooth for the text encoder was enabled: {train_text_encoder}.
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch): def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
logger.info( logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
...@@ -104,6 +137,8 @@ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight ...@@ -104,6 +137,8 @@ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight
del pipeline del pipeline
torch.cuda.empty_cache() torch.cuda.empty_cache()
return images
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
text_encoder_config = PretrainedConfig.from_pretrained( text_encoder_config = PretrainedConfig.from_pretrained(
...@@ -997,13 +1032,16 @@ def main(args): ...@@ -997,13 +1032,16 @@ def main(args):
global_step += 1 global_step += 1
if accelerator.is_main_process: if accelerator.is_main_process:
images = []
if global_step % args.checkpointing_steps == 0: if global_step % args.checkpointing_steps == 0:
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path) accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}") logger.info(f"Saved state to {save_path}")
if args.validation_prompt is not None and global_step % args.validation_steps == 0: if args.validation_prompt is not None and global_step % args.validation_steps == 0:
log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch) images = log_validation(
text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch
)
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs) progress_bar.set_postfix(**logs)
...@@ -1024,6 +1062,14 @@ def main(args): ...@@ -1024,6 +1062,14 @@ def main(args):
pipeline.save_pretrained(args.output_dir) pipeline.save_pretrained(args.output_dir)
if args.push_to_hub: if args.push_to_hub:
save_model_card(
repo_id,
images=images,
base_model=args.pretrained_model_name_or_path,
train_text_encoder=args.train_text_encoder,
prompt=args.instance_prompt,
repo_folder=args.output_dir,
)
upload_folder( upload_folder(
repo_id=repo_id, repo_id=repo_id,
folder_path=args.output_dir, folder_path=args.output_dir,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment