Unverified Commit 365e8461 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[SDXL DreamBooth LoRA] add support for text encoder fine-tuning (#4097)



* Allow low precision sd xl

* finish

* finish

* feat: initial draft for supporting text encoder lora finetuning for SDXL DreamBooth

* fix: variable assignments.

* add: autocast block.

* add debugging

* vae dtype hell

* fix: vae dtype hell.

* fix: vae dtype hell 3.

* clean up

* lora text encoder loader.

* fix: unwrapping models.

* add: tests.

* docs.

* handle unexpected keys.

* fix vae dtype in the final inference.

* fix scope problem.

* fix: save_model_card args.

* initialize: prefix to None.

* fix: dtype issues.

* apply gixes.

* debgging.

* debugging

* debugging

* debugging

* debugging

* debugging

* add: fast tests.

* pre-tokenize.

* address: will's comments.

* fix: loader and tests.

* fix: dataloader.

* simplify dataloader.

* length.

* simplification.

* make style && make quality

* simplify state_dict munging

* fix: tests.

* fix: state_dict packing.

* Apply suggestions from code review
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent fed12376
......@@ -164,6 +164,17 @@ Here's a side-by-side comparison of the with and without Refiner pipeline output
|---|---|
| ![](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/sks_dog.png) | ![](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/refined_sks_dog.png) |
### Training with text encoder(s)
Alongside the UNet, LoRA fine-tuning of the text encoders is also supported. To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind:
* SDXL has two text encoders. So, we fine-tune both using LoRA.
* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memory.
### Specifying a better VAE
SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
## Notes
In our experiments we found that SDXL yields very good initial results using the default settings of the script. We didn't explore further hyper-parameter tuning experiments, but we do encourage the community to explore this avenue further and share their results with us 🤗
......
......@@ -385,6 +385,42 @@ class ExamplesTestsAccelerate(unittest.TestCase):
starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
self.assertTrue(starts_with_unet)
def test_dreambooth_lora_sdxl_with_text_encoder(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth_lora_sdxl.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
--instance_data_dir docs/source/en/imgs
--instance_prompt photo
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--train_text_encoder
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"unet"` or `"text_encoder"` or `"text_encoder_2"` in their names.
keys = lora_state_dict.keys()
starts_with_unet = all(
k.startswith("unet") or k.startswith("text_encoder") or k.startswith("text_encoder_2") for k in keys
)
self.assertTrue(starts_with_unet)
def test_custom_diffusion(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
......
......@@ -59,7 +59,7 @@ if is_safetensors_available():
import safetensors
if is_transformers_available():
from transformers import CLIPTextModel, PreTrainedModel, PreTrainedTokenizer
from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel, PreTrainedTokenizer
if is_accelerate_available():
from accelerate import init_empty_weights
......@@ -108,7 +108,7 @@ class PatchedLoraProjection(nn.Module):
def text_encoder_attn_modules(text_encoder):
attn_modules = []
if isinstance(text_encoder, CLIPTextModel):
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
......@@ -1016,18 +1016,20 @@ class LoraLoaderMixin:
warnings.warn(warn_message)
@classmethod
def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lora_scale=1.0):
def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, prefix=None, lora_scale=1.0):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The key shoult be prefixed with an
A standard state dict containing the lora layer parameters. The key should be prefixed with an
additional `text_encoder` to distinguish between unet lora layers.
network_alpha (`float`):
See `LoRALinearLayer` for more details.
text_encoder (`CLIPTextModel`):
The text encoder model to load the LoRA layers into.
prefix (`str`):
Expected prefix of the `text_encoder` in the `state_dict`.
lora_scale (`float`):
How much to scale the output of the lora linear layer before it is added with the output of the regular
lora layer.
......@@ -1037,14 +1039,16 @@ class LoraLoaderMixin:
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
prefix = cls.text_encoder_name if prefix is None else prefix
if any(cls.text_encoder_name in key for key in keys):
# Load the layers corresponding to text encoder and make necessary adjustments.
text_encoder_keys = [k for k in keys if k.startswith(cls.text_encoder_name)]
text_encoder_keys = [k for k in keys if k.startswith(prefix)]
text_encoder_lora_state_dict = {
k.replace(f"{cls.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
}
if len(text_encoder_lora_state_dict) > 0:
logger.info(f"Loading {cls.text_encoder_name}.")
logger.info(f"Loading {prefix}.")
if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()):
# Convert from the old naming convention to the new naming convention.
......@@ -1184,23 +1188,10 @@ class LoraLoaderMixin:
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
"""
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
if save_function is None:
if safe_serialization:
def save_function(weights, filename):
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
else:
save_function = torch.save
os.makedirs(save_directory, exist_ok=True)
# Create a flat dictionary.
state_dict = {}
# Populate the dictionary.
if unet_lora_layers is not None:
weights = (
unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
......@@ -1222,6 +1213,38 @@ class LoraLoaderMixin:
state_dict.update(text_encoder_lora_state_dict)
# Save the model
self.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)
def write_lora_layers(
state_dict: Dict[str, torch.Tensor],
save_directory: str,
is_main_process: bool,
weight_name: str,
save_function: Callable,
safe_serialization: bool,
):
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
if save_function is None:
if safe_serialization:
def save_function(weights, filename):
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
else:
save_function = torch.save
os.makedirs(save_directory, exist_ok=True)
if weight_name is None:
if safe_serialization:
weight_name = LORA_WEIGHT_NAME_SAFE
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import inspect
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
......@@ -841,3 +842,66 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
return (image,)
return StableDiffusionXLPipelineOutput(images=image)
# Overrride to properly handle the loading and unloading of the additional text encoder.
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet)
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
if len(text_encoder_state_dict) > 0:
self.load_lora_into_text_encoder(
text_encoder_state_dict,
network_alpha=network_alpha,
text_encoder=self.text_encoder,
prefix="text_encoder",
lora_scale=self.lora_scale,
)
text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
if len(text_encoder_2_state_dict) > 0:
self.load_lora_into_text_encoder(
text_encoder_2_state_dict,
network_alpha=network_alpha,
text_encoder=self.text_encoder_2,
prefix="text_encoder_2",
lora_scale=self.lora_scale,
)
@classmethod
def save_lora_weights(
self,
save_directory: Union[str, os.PathLike],
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = False,
):
state_dict = {}
def pack_weights(layers, prefix):
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
return layers_state_dict
state_dict.update(pack_weights(unet_lora_layers, "unet"))
if text_encoder_lora_layers and text_encoder_2_lora_layers:
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
self.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)
def _remove_text_encoder_monkey_patch(self):
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
......@@ -21,9 +21,16 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub.repocard import RepoCard
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import (
AutoencoderKL,
DDIMScheduler,
EulerDiscreteScheduler,
StableDiffusionPipeline,
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin, PatchedLoraProjection, text_encoder_attn_modules
from diffusers.models.attention_processor import (
Attention,
......@@ -399,7 +406,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
)
self.assertIsInstance(module.processor, attn_proc_class)
def test_unload_lora(self):
def test_unload_lora_sd(self):
pipeline_components, lora_components = self.get_dummy_components()
_, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
sd_pipe = StableDiffusionPipeline(**pipeline_components)
......@@ -503,6 +510,175 @@ class LoraLoaderMixinTests(unittest.TestCase):
self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))
class SDXLLoraLoaderMixinTests(unittest.TestCase):
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
# SD2-specific config below
attention_head_dim=(2, 4),
use_linear_projection=True,
addition_embed_type="text_time",
addition_time_embed_dim=8,
transformer_layers_per_block=(1, 2),
projection_class_embeddings_input_dim=80, # 6 * 8 + 32
cross_attention_dim=64,
)
scheduler = EulerDiscreteScheduler(
beta_start=0.00085,
beta_end=0.012,
steps_offset=1,
beta_schedule="scaled_linear",
timestep_spacing="leading",
)
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
sample_size=128,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
# SD2-specific config below
hidden_act="gelu",
projection_dim=32,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
text_encoder_one_lora_layers = create_text_encoder_lora_layers(text_encoder)
text_encoder_two_lora_layers = create_text_encoder_lora_layers(text_encoder_2)
pipeline_components = {
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"text_encoder_2": text_encoder_2,
"tokenizer": tokenizer,
"tokenizer_2": tokenizer_2,
}
lora_components = {
"unet_lora_layers": unet_lora_layers,
"text_encoder_one_lora_layers": text_encoder_one_lora_layers,
"text_encoder_two_lora_layers": text_encoder_two_lora_layers,
"unet_lora_attn_procs": unet_lora_attn_procs,
}
return pipeline_components, lora_components
def get_dummy_inputs(self, with_generator=True):
batch_size = 1
sequence_length = 10
num_channels = 4
sizes = (32, 32)
generator = torch.manual_seed(0)
noise = floats_tensor((batch_size, num_channels) + sizes)
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
pipeline_inputs = {
"prompt": "A painting of a squirrel eating a burger",
"num_inference_steps": 2,
"guidance_scale": 6.0,
"output_type": "np",
}
if with_generator:
pipeline_inputs.update({"generator": generator})
return noise, input_ids, pipeline_inputs
def test_lora_save_load(self):
pipeline_components, lora_components = self.get_dummy_components()
sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
_, _, pipeline_inputs = self.get_dummy_inputs()
original_images = sd_pipe(**pipeline_inputs).images
orig_image_slice = original_images[0, -3:, -3:, -1]
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionXLPipeline.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
sd_pipe.load_lora_weights(tmpdirname)
lora_images = sd_pipe(**pipeline_inputs).images
lora_image_slice = lora_images[0, -3:, -3:, -1]
# Outputs shouldn't match.
self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))
def test_unload_lora_sdxl(self):
pipeline_components, lora_components = self.get_dummy_components()
_, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
orig_image_slice = original_images[0, -3:, -3:, -1]
# Emulate training.
set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True)
set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True)
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionXLPipeline.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
sd_pipe.load_lora_weights(tmpdirname)
lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
lora_image_slice = lora_images[0, -3:, -3:, -1]
# Unload LoRA parameters.
sd_pipe.unload_lora_weights()
original_images_two = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
orig_image_slice_two = original_images_two[0, -3:, -3:, -1]
assert not np.allclose(
orig_image_slice, lora_image_slice
), "LoRA parameters should lead to a different image slice."
assert not np.allclose(
orig_image_slice_two, lora_image_slice
), "LoRA parameters should lead to a different image slice."
assert np.allclose(
orig_image_slice, orig_image_slice_two, atol=1e-3
), "Unloading LoRA parameters should lead to results similar to what was obtained with the pipeline without any LoRA parameters."
@slow
@require_torch_gpu
class LoraIntegrationTests(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment