Unverified commit 3045fb27 authored by Sayak Paul, committed by GitHub

[DreamBooth] add text encoder LoRA support in the DreamBooth training script (#3130)

* add: LoRA text encoder support for DreamBooth example.

* fix initialization.

* fix: modification call.

* add: entry in the readme.

* use dog dataset from hub.

* fix: params to clip.

* add entry to the LoRA doc.

* add: tests for lora.

* remove unnecessary list comprehension.
parent 7b0ba482
...@@ -60,7 +60,18 @@ DreamBooth finetuning is very sensitive to hyperparameters and easy to overfit.
<frameworkcontent>
<pt>
Let's try DreamBooth with a
[few images of a dog](https://huggingface.co/datasets/diffusers/dog-example);
download and save them to a directory and then set the `INSTANCE_DIR` environment variable to that path:
```python
from huggingface_hub import snapshot_download

local_dir = "./path_to_training_images"
snapshot_download(
"diffusers/dog-example",
local_dir=local_dir, repo_type="dataset",
ignore_patterns=".gitattributes",
)
```
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
......
...@@ -16,7 +16,9 @@ specific language governing permissions and limitations under the License.
<Tip warning={true}>
Currently, LoRA is only supported for the attention layers of the [`UNet2DConditionModel`]. We also
support LoRA fine-tuning of the text encoder for DreamBooth in a limited capacity. For more details on how we support
LoRA fine-tuning of the text encoder, refer to the discussion on [this PR](https://github.com/huggingface/diffusers/pull/2918).
</Tip>
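For orientation, here is a minimal sketch (not part of this commit) of what "attention layers of the UNet" means in practice; the model ID is only an example, and the printed count and names can vary by checkpoint:

```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet"
)

# Every entry in `unet.attn_processors` is an attention layer whose processor
# can be replaced by a `LoRAAttnProcessor`, which is what the DreamBooth LoRA
# training script does for the UNet.
print(len(unet.attn_processors))        # 32 for the Stable Diffusion v1 UNet
print(sorted(unet.attn_processors)[0])  # e.g. "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor"
```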
...@@ -175,6 +177,11 @@ accelerate launch train_dreambooth_lora.py \
--push_to_hub
```
It's also possible to additionally fine-tune the text encoder with LoRA. In most cases this leads
to better results at the cost of slightly more compute. To enable fine-tuning the text encoder with LoRA,
specify the `--train_text_encoder` argument when launching the `train_dreambooth_lora.py` script.
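As a rough inference sketch (paths and the prompt are placeholders; the pipeline-level loading call mirrors the one used at the end of the updated training script, and newer diffusers releases expose the equivalent `pipe.load_lora_weights(...)`):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

# Load the LoRA layers produced by train_dreambooth_lora.py. The saved file holds
# the UNet layers and, if --train_text_encoder was passed, the text encoder layers too.
pipe.load_attn_procs("path-to-your-lora-output-dir")  # placeholder output directory

image = pipe("A photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("dog-bucket.png")
```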
### Inference[[dreambooth-inference]]
Now you can use the model for inference by loading the base model in the [`StableDiffusionPipeline`]:
......
...@@ -45,15 +45,28 @@ write_basic_config()
### Dog toy example
Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
Let's first download it locally:
```python
from huggingface_hub import snapshot_download
local_dir = "./dog"
snapshot_download(
"diffusers/dog-example",
local_dir=local_dir, repo_type="dataset",
ignore_patterns=".gitattributes",
)
```
And launch the training using:
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="path-to-save-model"
accelerate launch train_dreambooth.py \
...@@ -77,7 +90,7 @@ According to the paper, it's recommended to generate `num_epochs * num_samples`
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="dog"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"
...@@ -108,7 +121,7 @@ To install `bitsandbytes` please refer to this [readme](https://github.com/TimDet
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="dog"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"
...@@ -141,7 +154,7 @@ It is possible to run dreambooth on a 12GB GPU by using the following optimizati
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="dog"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"
...@@ -185,7 +198,7 @@ does not seem to be compatible with DeepSpeed at the moment.
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="dog"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"
...@@ -217,7 +230,7 @@ ___Note: Training text encoder requires more memory, with this option the traini
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="dog"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"
...@@ -300,7 +313,7 @@ Now, you can launch the training. Here we will use [Stable Diffusion 1-5](https:
```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="path-to-save-model"
```
...@@ -342,6 +355,12 @@ The final LoRA embedding weights have been uploaded to [patrickvonplaten/lora_dr
The training results are summarized [here](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5).
You can use the `Step` slider to see how the model learned the features of our subject while the model trained.
Optionally, we can also train additional LoRA layers for the text encoder. Specify the `--train_text_encoder` argument above for that. If you're interested in knowing more about how we
enable this support, check out this [PR](https://github.com/huggingface/diffusers/pull/2918).
With the default hyperparameters from above, training seems to go in a positive direction. Check out [this panel](https://wandb.ai/sayakpaul/dreambooth-lora/reports/test-23-04-17-17-00-13---Vmlldzo0MDkwNjMy). The trained LoRA layers are available [here](https://huggingface.co/sayakpaul/dreambooth).
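As an optional sanity check, you can inspect the saved LoRA weights and confirm both parameter groups are present. This is only a sketch: `path-to-save-model` is the `OUTPUT_DIR` placeholder used above, and the file name and key prefixes match what `LoraLoaderMixin.save_lora_weights()` writes and what the new tests assert:

```python
import torch

lora_state_dict = torch.load("path-to-save-model/pytorch_lora_weights.bin")

# Keys are prefixed with "unet" or "text_encoder".
unet_keys = [k for k in lora_state_dict if k.startswith("unet")]
text_encoder_keys = [k for k in lora_state_dict if k.startswith("text_encoder")]

print(f"UNet LoRA tensors: {len(unet_keys)}")
print(f"Text encoder LoRA tensors: {len(text_encoder_keys)}")  # 0 unless --train_text_encoder was used
```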
### Inference
After training, LoRA weights can be loaded very easily into the original pipeline. First, you need to
...@@ -386,7 +405,7 @@ pip install -U -r requirements_flax.txt
```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="path-to-save-model"
python train_dreambooth_flax.py \
...@@ -405,7 +424,7 @@ python train_dreambooth_flax.py \
```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export INSTANCE_DIR="dog"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"
...@@ -429,7 +448,7 @@ python train_dreambooth_flax.py \
```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export INSTANCE_DIR="dog"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"
......
...@@ -15,6 +15,7 @@
import argparse
import hashlib
import itertools
import logging
import math
import os
...@@ -43,12 +44,13 @@ from diffusers import (
DDPMScheduler,
DiffusionPipeline,
DPMSolverMultistepScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.optimization import get_scheduler
from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
...@@ -58,7 +60,7 @@ check_min_version("0.16.0.dev0")
logger = get_logger(__name__)
def save_model_card(repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None):
img_str = ""
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
...@@ -83,6 +85,8 @@ inference: true
These are LoRA adaptation weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
{img_str}
LoRA for the text encoder was enabled: {train_text_encoder}.
""" """
with open(os.path.join(repo_folder, "README.md"), "w") as f: with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card) f.write(yaml + model_card)
...@@ -219,6 +223,11 @@ def parse_args(input_args=None):
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--train_text_encoder",
action="store_true",
help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
)
parser.add_argument(
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
)
...@@ -547,7 +556,13 @@ def main(args):
# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
# TODO (sayakpaul): Remove this check when gradient accumulation with two models is enabled in accelerate.
if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
raise ValueError(
"Gradient accumulation is not supported when training the text encoder in distributed training. "
"Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
...@@ -691,7 +706,7 @@
# => 32 layers
# Set correct lora layers
unet_lora_attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
...@@ -703,12 +718,33 @@
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
unet_lora_attn_procs[name] = LoRAAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
)
unet.set_attn_processor(unet_lora_attn_procs)
unet_lora_layers = AttnProcsLayers(unet.attn_processors)
accelerator.register_for_checkpointing(unet_lora_layers)
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
# So, instead, we monkey-patch the forward calls of its attention-blocks. For this,
# we first load a dummy pipeline with the text encoder and then do the monkey-patching.
text_encoder_lora_layers = None
if args.train_text_encoder:
text_lora_attn_procs = {}
for name, module in text_encoder.named_modules():
if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
text_lora_attn_procs[name] = LoRAAttnProcessor(
hidden_size=module.out_features, cross_attention_dim=None
)
text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
temp_pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, text_encoder=text_encoder
)
temp_pipeline._modify_text_encoder(text_lora_attn_procs)
text_encoder = temp_pipeline.text_encoder
accelerator.register_for_checkpointing(text_encoder_lora_layers)
del temp_pipeline
if args.scale_lr:
args.learning_rate = (
...@@ -739,8 +775,13 @@
optimizer_class = torch.optim.AdamW
# Optimizer creation
params_to_optimize = (
itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
if args.train_text_encoder
else unet_lora_layers.parameters()
)
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
...@@ -784,8 +825,13 @@
)
# Prepare everything with our `accelerator`.
if args.train_text_encoder:
unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler
)
else:
unet_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet_lora_layers, optimizer, train_dataloader, lr_scheduler
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
...@@ -845,6 +891,8 @@
for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
if args.train_text_encoder:
text_encoder.train()
for step, batch in enumerate(train_dataloader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
...@@ -900,7 +948,11 @@
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = (
itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
if args.train_text_encoder
else unet_lora_layers.parameters()
)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
...@@ -914,7 +966,14 @@
if global_step % args.checkpointing_steps == 0:
if accelerator.is_main_process:
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
# We combine the text encoder and UNet LoRA parameters with a simple
# custom logic. `accelerator.save_state()` won't know that. So,
# use `LoraLoaderMixin.save_lora_weights()`.
LoraLoaderMixin.save_lora_weights(
save_directory=save_path,
unet_lora_layers=unet_lora_layers,
text_encoder_lora_layers=text_encoder_lora_layers,
)
logger.info(f"Saved state to {save_path}") logger.info(f"Saved state to {save_path}")
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
...@@ -970,7 +1029,12 @@
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = unet.to(torch.float32)
text_encoder = text_encoder.to(torch.float32)
LoraLoaderMixin.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_layers,
text_encoder_lora_layers=text_encoder_lora_layers,
)
# Final inference
# Load previous pipeline
...@@ -981,7 +1045,7 @@
pipeline = pipeline.to(accelerator.device)
# load attention processors
pipeline.load_attn_procs(args.output_dir)
# run inference
if args.validation_prompt and args.num_validation_images > 0:
...@@ -1010,6 +1074,7 @@
repo_id,
images=images,
base_model=args.pretrained_model_name_or_path,
train_text_encoder=args.train_text_encoder,
prompt=args.instance_prompt,
repo_folder=args.output_dir,
)
......
...@@ -23,6 +23,7 @@ import tempfile
import unittest
from typing import List
import torch
from accelerate.utils import write_basic_config
from diffusers import DiffusionPipeline, UNet2DConditionModel
...@@ -221,6 +222,68 @@ class ExamplesTestsAccelerate(unittest.TestCase):
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))
def test_dreambooth_lora(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth_lora.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
--instance_data_dir docs/source/en/imgs
--instance_prompt photo
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"unet"` in their names.
starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
self.assertTrue(starts_with_unet)
def test_dreambooth_lora_with_text_encoder(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth_lora.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
--instance_data_dir docs/source/en/imgs
--instance_prompt photo
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--train_text_encoder
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
# the names of the keys of the state dict should either start with `unet`
# or `text_encoder`.
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
keys = lora_state_dict.keys()
is_correct_naming = all(k.startswith("unet") or k.startswith("text_encoder") for k in keys)
self.assertTrue(is_correct_naming)
def test_custom_diffusion(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
......