"tests/vscode:/vscode.git/clone" did not exist on "ef2ea33c3bc061fffa8bc4ccd640306ca1a1847d"
Unverified commit d4642144, authored by Patrick von Platen, committed by GitHub

Let's make sure that dreambooth always uploads to the Hub (#3272)

* Update Dreambooth README

* Adapt all docs as well

* automatically write model card

* fix

* make style
parent 62906682
@@ -98,7 +98,8 @@ accelerate launch train_dreambooth.py \
   --learning_rate=5e-6 \
   --lr_scheduler="constant" \
   --lr_warmup_steps=0 \
-  --max_train_steps=400
+  --max_train_steps=400 \
+  --push_to_hub
 ```
 </pt>
 <jax>
@@ -161,7 +162,8 @@ accelerate launch train_dreambooth.py \
   --lr_scheduler="constant" \
   --lr_warmup_steps=0 \
   --num_class_images=200 \
-  --max_train_steps=800
+  --max_train_steps=800 \
+  --push_to_hub
 ```
 </pt>
 <jax>
@@ -225,7 +227,8 @@ accelerate launch train_dreambooth.py \
   --lr_scheduler="constant" \
   --lr_warmup_steps=0 \
   --num_class_images=200 \
-  --max_train_steps=800
+  --max_train_steps=800 \
+  --push_to_hub
 ```
 </pt>
 <jax>
@@ -387,7 +390,8 @@ accelerate launch train_dreambooth.py \
   --lr_scheduler="constant" \
   --lr_warmup_steps=0 \
   --num_class_images=200 \
-  --max_train_steps=800
+  --max_train_steps=800 \
+  --push_to_hub
 ```
 
 ### 12GB GPU
@@ -418,7 +422,8 @@ accelerate launch train_dreambooth.py \
   --lr_scheduler="constant" \
   --lr_warmup_steps=0 \
   --num_class_images=200 \
-  --max_train_steps=800
+  --max_train_steps=800 \
+  --push_to_hub
 ```
 
 ### 8 GB GPU
@@ -464,7 +469,8 @@ accelerate launch train_dreambooth.py \
   --lr_warmup_steps=0 \
   --num_class_images=200 \
   --max_train_steps=800 \
-  --mixed_precision=fp16
+  --mixed_precision=fp16 \
+  --push_to_hub
 ```
 
 ## Inference
@@ -80,7 +80,8 @@ accelerate launch train_dreambooth.py \
   --learning_rate=5e-6 \
   --lr_scheduler="constant" \
   --lr_warmup_steps=0 \
-  --max_train_steps=400
+  --max_train_steps=400 \
+  --push_to_hub
 ```
 
 ### Training with prior-preservation loss
@@ -109,7 +110,8 @@ accelerate launch train_dreambooth.py \
   --lr_scheduler="constant" \
   --lr_warmup_steps=0 \
   --num_class_images=200 \
-  --max_train_steps=800
+  --max_train_steps=800 \
+  --push_to_hub
 ```
@@ -141,7 +143,8 @@ accelerate launch train_dreambooth.py \
   --lr_scheduler="constant" \
   --lr_warmup_steps=0 \
   --num_class_images=200 \
-  --max_train_steps=800
+  --max_train_steps=800 \
+  --push_to_hub
 ```
@@ -176,7 +179,8 @@ accelerate launch train_dreambooth.py \
   --lr_scheduler="constant" \
   --lr_warmup_steps=0 \
   --num_class_images=200 \
-  --max_train_steps=800
+  --max_train_steps=800 \
+  --push_to_hub
 ```
@@ -218,7 +222,8 @@ accelerate launch --mixed_precision="fp16" train_dreambooth.py \
   --lr_scheduler="constant" \
   --lr_warmup_steps=0 \
   --num_class_images=200 \
-  --max_train_steps=800
+  --max_train_steps=800 \
+  --push_to_hub
 ```
 
 ### Fine-tune text encoder with the UNet.
@@ -251,7 +256,8 @@ accelerate launch train_dreambooth.py \
   --lr_scheduler="constant" \
   --lr_warmup_steps=0 \
   --num_class_images=200 \
-  --max_train_steps=800
+  --max_train_steps=800 \
+  --push_to_hub
 ```
 
 ### Using DreamBooth for pipelines other than Stable Diffusion
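With `--push_to_hub`, the final pipeline saved to `--output_dir` is uploaded to a Hub repository (see the `upload_folder` call in the training-script diff below), so inference can load it straight from the Hub by repo id. A minimal sketch, where the repo name and prompt are illustrative placeholders rather than values taken from this commit:

```python
# Minimal inference sketch for a DreamBooth model pushed to the Hub.
# "your-username/dreambooth-model" and the prompt are illustrative placeholders.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "your-username/dreambooth-model", torch_dtype=torch.float16
).to("cuda")

image = pipe("a photo of sks dog in a bucket", num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("dog-bucket.png")
```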
@@ -61,6 +61,39 @@ check_min_version("0.17.0.dev0")
 
 logger = get_logger(__name__)
 
+def save_model_card(repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None):
+    img_str = ""
+    for i, image in enumerate(images):
+        image.save(os.path.join(repo_folder, f"image_{i}.png"))
+        img_str += f"![img_{i}](./image_{i}.png)\n"
+
+    yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {base_model}
+instance_prompt: {prompt}
+tags:
+- stable-diffusion
+- stable-diffusion-diffusers
+- text-to-image
+- diffusers
+- dreambooth
+inference: true
+---
+    """
+    model_card = f"""
+# DreamBooth - {repo_id}
+
+This is a dreambooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/).
+You can find some example images in the following. \n
+{img_str}
+
+DreamBooth for the text encoder was enabled: {train_text_encoder}.
+"""
+    with open(os.path.join(repo_folder, "README.md"), "w") as f:
+        f.write(yaml + model_card)
+
+
 def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
     logger.info(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
@@ -104,6 +137,8 @@ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
     del pipeline
     torch.cuda.empty_cache()
 
+    return images
+
 
 def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
     text_encoder_config = PretrainedConfig.from_pretrained(
@@ -997,13 +1032,16 @@ def main(args):
                 global_step += 1
 
                 if accelerator.is_main_process:
+                    images = []
                     if global_step % args.checkpointing_steps == 0:
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")
 
                     if args.validation_prompt is not None and global_step % args.validation_steps == 0:
-                        log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)
+                        images = log_validation(
+                            text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch
+                        )
 
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
@@ -1024,6 +1062,14 @@ def main(args):
         pipeline.save_pretrained(args.output_dir)
 
         if args.push_to_hub:
+            save_model_card(
+                repo_id,
+                images=images,
+                base_model=args.pretrained_model_name_or_path,
+                train_text_encoder=args.train_text_encoder,
+                prompt=args.instance_prompt,
+                repo_folder=args.output_dir,
+            )
             upload_folder(
                 repo_id=repo_id,
                 folder_path=args.output_dir,
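For context, a standalone sketch of how the new `save_model_card` helper behaves: it writes the validation images plus a `README.md` with YAML metadata into the output folder, which `upload_folder` then pushes together with the weights. Everything below except the function signature is illustrative (placeholder repo id, prompt, and paths), and it assumes the snippet is run from `examples/dreambooth` so `train_dreambooth.py` is importable:

```python
# Illustrative, standalone use of the save_model_card helper added in this commit.
# Placeholder values throughout; assumes train_dreambooth.py is on the import path.
import os

from PIL import Image

from train_dreambooth import save_model_card

repo_folder = "dreambooth-output"
os.makedirs(repo_folder, exist_ok=True)

# Stand-ins for the validation images that log_validation now returns.
validation_images = [Image.new("RGB", (512, 512), color="gray")]

save_model_card(
    repo_id="your-username/dreambooth-model",
    images=validation_images,
    base_model="runwayml/stable-diffusion-v1-5",
    train_text_encoder=False,
    prompt="a photo of sks dog",
    repo_folder=repo_folder,
)
# Result: repo_folder now contains image_0.png and a README.md model card.
```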