Unverified Commit 743a5697 authored by Linoy Tsaban, committed by GitHub

[flux dreambooth lora training] make LoRA target modules configurable + small bug fix (#9646)

* make lora target modules configurable and change the default

* style

* make lora target modules configurable and change the default

* fix bug when using prodigy and training te

* fix mixed precision training as proposed in https://github.com/huggingface/diffusers/pull/9565 for full dreambooth as well

* add test and notes

* style

* address Sayak's comments

* style

* fix test

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent db5b6a96
@@ -170,6 +170,21 @@ accelerate launch train_dreambooth_lora_flux.py \
  --push_to_hub
```
### Target Modules
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the UNet that relate the image representations to the prompts that describe them.
More recently, SOTA text-to-image diffusion models replaced the UNet with a diffusion Transformer (DiT). With this change, we may also want to explore
applying LoRA training to different types of layers and blocks. To allow more flexibility and control over the targeted modules, we added `--lora_layers`, in which you can specify the exact modules for LoRA training as a comma-separated string (see the sketch below for how this string maps to the LoRA configuration). Here are some examples of target modules you can provide:
- for attention-only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"`
- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"`
- to train the same modules as in the ostris ai-toolkit / replicate trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear,norm1.linear,norm.linear,proj_mlp,proj_out"`
> [!NOTE]
> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma-separated string:
> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. `single_transformer_blocks.i.attn.to_k`
> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. `transformer_blocks.i.attn.to_k`

> [!NOTE]
> Keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.
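For reference, here is a minimal sketch of how the comma-separated `--lora_layers` value becomes the `target_modules` of a PEFT `LoraConfig`, mirroring the logic in `train_dreambooth_lora_flux.py` (the example string and the rank of 4 are illustrative; the script reads these from the CLI arguments):

```python
# Minimal sketch: turning a `--lora_layers` value into PEFT target modules.
from peft import LoraConfig

lora_layers = "attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"  # illustrative value
target_modules = [layer.strip() for layer in lora_layers.split(",")]

transformer_lora_config = LoraConfig(
    r=4,                            # --rank
    lora_alpha=4,                   # the script sets lora_alpha equal to the rank
    init_lora_weights="gaussian",
    target_modules=target_modules,
)
# The script then attaches the adapter to the Flux transformer via
# transformer.add_adapter(transformer_lora_config).
```

When `--lora_layers` is omitted, the script falls back to a default list covering the attention and feed-forward projections of every block (the same set as the fal trainer example above).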
### Text Encoder Training
Alongside the transformer, fine-tuning of the CLIP text encoder is also supported.
......
@@ -37,6 +37,7 @@ class DreamBoothLoRAFlux(ExamplesTestsAccelerate):
    instance_prompt = "photo"
    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
    script_path = "examples/dreambooth/train_dreambooth_lora_flux.py"
    transformer_layer_type = "single_transformer_blocks.0.attn.to_k"

    def test_dreambooth_lora_flux(self):
        with tempfile.TemporaryDirectory() as tmpdir:
@@ -136,6 +137,43 @@ class DreamBoothLoRAFlux(ExamplesTestsAccelerate):
            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_layers(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --cache_latents
                --learning_rate 5.0e-04
                --scale_lr
                --lora_layers {self.transformer_layer_type}
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names. In this test, only the params of
            # `transformer.single_transformer_blocks.0.attn.to_k` should be in the state dict.
            starts_with_transformer = all(
                key.startswith("transformer.single_transformer_blocks.0.attn.to_k") for key in lora_state_dict.keys()
            )
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
......
@@ -161,7 +161,7 @@ def log_validation(
        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
        f" {args.validation_prompt}."
    )
-    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
+    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    # run inference
@@ -1579,7 +1579,7 @@ def main(args):
                )

                # handle guidance
-                if transformer.config.guidance_embeds:
+                if accelerator.unwrap_model(transformer).config.guidance_embeds:
                    guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
                    guidance = guidance.expand(model_input.shape[0])
                else:
@@ -1693,6 +1693,8 @@ def main(args):
                # create pipeline
                if not args.train_text_encoder:
                    text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
                    text_encoder_one.to(weight_dtype)
                    text_encoder_two.to(weight_dtype)
                else:  # even when training the text encoder we're only training text encoder one
                    text_encoder_two = text_encoder_cls_two.from_pretrained(
                        args.pretrained_model_name_or_path,
......
@@ -554,6 +554,15 @@ def parse_args(input_args=None):
        "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
    )
    parser.add_argument(
        "--lora_layers",
        type=str,
        default=None,
        help=(
            'The transformer modules to apply LoRA training on. Please specify the layers in a comma separated string. '
            'E.g. - "to_k,to_q,to_v,to_out.0" will result in LoRA training of attention layers only.'
        ),
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
@@ -1186,12 +1195,30 @@ def main(args):
        if args.train_text_encoder:
            text_encoder_one.gradient_checkpointing_enable()
-    # now we will add new LoRA weights to the attention layers
+    if args.lora_layers is not None:
        target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
    else:
        target_modules = [
            "attn.to_k",
            "attn.to_q",
            "attn.to_v",
            "attn.to_out.0",
            "attn.add_k_proj",
            "attn.add_q_proj",
            "attn.add_v_proj",
            "attn.to_add_out",
            "ff.net.0.proj",
            "ff.net.2",
            "ff_context.net.0.proj",
            "ff_context.net.2",
        ]

    # now we will add new LoRA weights to the transformer layers
    transformer_lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=args.rank,
        init_lora_weights="gaussian",
-        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+        target_modules=target_modules,
    )
    transformer.add_adapter(transformer_lora_config)
    if args.train_text_encoder:
@@ -1367,7 +1394,7 @@ def main(args):
                f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
                f"When using prodigy only learning_rate is used as the initial learning rate."
            )
-            # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
+            # changes the learning rate of text_encoder_parameters_one to be
            # --learning_rate
            params_to_optimize[1]["lr"] = args.learning_rate
......