Unverified Commit c6e08ecd authored by Linoy Tsaban, committed by GitHub

[Sd3 Dreambooth LoRA] Add text encoder training for the clip encoders (#8630)



* add clip text-encoder training

* no dora

* text encoder training fixes

* text encoder training fixes

* text encoder training fixes

* text encoder training fixes

* text encoder training fixes

* text encoder training fixes

* add text_encoder layers to save_lora

* style

* fix imports

* style

* fix text encoder

* review changes

* review changes

* review changes

* minor change

* add lora tag

* style

* add readme notes

* add tests for clip encoders

* style

* typo

* fixes

* style

* Update tests/lora/test_lora_layers_sd3.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update examples/dreambooth/README_sd3.md
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* minor readme change

---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 4ad7a1f5
@@ -147,6 +147,40 @@ accelerate launch train_dreambooth_lora_sd3.py \
  --push_to_hub
```
### Text Encoder Training
Alongside the transformer, LoRA fine-tuning of the CLIP text encoders is now also supported.
To do so, just specify `--train_text_encoder` while launching training. Please keep the following in mind:

> [!NOTE]
> SD3 has three text encoders (CLIP L/14, OpenCLIP bigG/14, and T5-v1.1-XXL).
> Enabling `--train_text_encoder` performs LoRA fine-tuning of both **CLIP encoders**. At the moment, T5 fine-tuning is not supported; its weights remain frozen even when text encoder training is enabled.
To perform DreamBooth LoRA with text-encoder training, run:
```bash
export MODEL_NAME="stabilityai/stable-diffusion-3-medium-diffusers"
export OUTPUT_DIR="trained-sd3-lora"
accelerate launch train_dreambooth_lora_sd3.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --output_dir=$OUTPUT_DIR \
  --dataset_name="Norod78/Yarn-art-style" \
  --instance_prompt="a photo of TOK yarn art dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --train_text_encoder \
  --gradient_accumulation_steps=1 \
  --optimizer="prodigy" \
  --learning_rate=1.0 \
  --text_encoder_lr=1.0 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=1500 \
  --rank=32 \
  --seed="0" \
  --push_to_hub
```
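
Once training completes, the resulting LoRA (transformer plus CLIP text encoder layers) can be loaded back into the pipeline with `load_lora_weights`. Below is a minimal inference sketch, assuming the `trained-sd3-lora` output directory and instance token from the command above:

```python
import torch
from diffusers import StableDiffusion3Pipeline

# Load the base model, then attach the trained LoRA weights.
# `load_lora_weights` picks up the transformer LoRA as well as any
# CLIP text encoder LoRA layers stored in the checkpoint.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("trained-sd3-lora")

image = pipe("a photo of TOK yarn art dog in a bucket", num_inference_steps=28).images[0]
image.save("yarn_art_dog.png")
```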
## Other notes
We default to the "logit_normal" weighting scheme for the loss, following the SD3 paper. Thanks to @bghira for helping us discover that training may incur numerical instabilities with the other weighting schemes supported by the training script.
\ No newline at end of file
@@ -1601,6 +1601,8 @@ class SD3LoraLoaderMixin:
        cls,
        save_directory: Union[str, os.PathLike],
        transformer_lora_layers: Dict[str, torch.nn.Module] = None,
        text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
@@ -1632,12 +1634,20 @@ class SD3LoraLoaderMixin:
            layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
            return layers_state_dict

        if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
            raise ValueError(
                "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
            )

        if transformer_lora_layers:
            state_dict.update(pack_weights(transformer_lora_layers, cls.transformer_name))

        if text_encoder_lora_layers:
            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))

        if text_encoder_2_lora_layers:
            state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))

        # Save the model
        cls.write_lora_layers(
            state_dict=state_dict,
...
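
For illustration, here is a minimal sketch (not part of this diff) of how the extended `save_lora_weights` signature might be called; the adapter configs, target modules, and output directory below are placeholders chosen for the example:

```python
import torch
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
)

# Attach fresh LoRA adapters; during real training these would already be
# present and trained rather than randomly initialized.
text_lora_config = LoraConfig(r=4, lora_alpha=4, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"])
pipe.text_encoder.add_adapter(text_lora_config)
pipe.text_encoder_2.add_adapter(text_lora_config)
transformer_lora_config = LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"])
pipe.transformer.add_adapter(transformer_lora_config)

# Pack the transformer and both CLIP text encoder LoRA state dicts into a
# single `pytorch_lora_weights.safetensors`; the text encoder arguments are
# optional and can be omitted when only the transformer was trained.
StableDiffusion3Pipeline.save_lora_weights(
    save_directory="sd3-lora-checkpoint",
    transformer_lora_layers=get_peft_model_state_dict(pipe.transformer),
    text_encoder_lora_layers=get_peft_model_state_dict(pipe.text_encoder),
    text_encoder_2_lora_layers=get_peft_model_state_dict(pipe.text_encoder_2),
)
```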
@@ -137,6 +137,15 @@ class SD3LoRATests(unittest.TestCase):
        )
        return lora_config

    def get_lora_config_for_text_encoders(self):
        text_lora_config = LoraConfig(
            r=4,
            lora_alpha=4,
            init_lora_weights="gaussian",
            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
        )
        return text_lora_config

    def test_simple_inference_with_transformer_lora_save_load(self):
        components = self.get_dummy_components()
        transformer_config = self.get_lora_config_for_transformer()
@@ -173,6 +182,55 @@ class SD3LoRATests(unittest.TestCase):
            "Loading from saved checkpoints should give same results.",
        )

    def test_simple_inference_with_clip_encoders_lora_save_load(self):
        components = self.get_dummy_components()
        transformer_config = self.get_lora_config_for_transformer()
        text_encoder_config = self.get_lora_config_for_text_encoders()

        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)

        pipe.transformer.add_adapter(transformer_config)
        pipe.text_encoder.add_adapter(text_encoder_config)
        pipe.text_encoder_2.add_adapter(text_encoder_config)

        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder.")
        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2.")

        inputs = self.get_dummy_inputs(torch_device)
        images_lora = pipe(**inputs).images

        with tempfile.TemporaryDirectory() as tmpdirname:
            transformer_state_dict = get_peft_model_state_dict(pipe.transformer)
            text_encoder_one_state_dict = get_peft_model_state_dict(pipe.text_encoder)
            text_encoder_two_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)

            self.pipeline_class.save_lora_weights(
                save_directory=tmpdirname,
                transformer_lora_layers=transformer_state_dict,
                text_encoder_lora_layers=text_encoder_one_state_dict,
                text_encoder_2_lora_layers=text_encoder_two_state_dict,
            )

            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
            pipe.unload_lora_weights()
            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))

        inputs = self.get_dummy_inputs(torch_device)
        images_lora_from_pretrained = pipe(**inputs).images

        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text_encoder_one")
        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text_encoder_two")

        self.assertTrue(
            np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
            "Loading from saved checkpoints should give same results.",
        )

    def test_simple_inference_with_transformer_lora_and_scale(self):
        components = self.get_dummy_components()
        transformer_lora_config = self.get_lora_config_for_transformer()
@@ -206,6 +264,44 @@ class SD3LoRATests(unittest.TestCase):
            "Lora + 0 scale should lead to same result as no LoRA",
        )

    def test_simple_inference_with_clip_encoders_lora_and_scale(self):
        components = self.get_dummy_components()
        transformer_lora_config = self.get_lora_config_for_transformer()
        text_encoder_config = self.get_lora_config_for_text_encoders()

        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output_no_lora = pipe(**inputs).images

        pipe.transformer.add_adapter(transformer_lora_config)
        pipe.text_encoder.add_adapter(text_encoder_config)
        pipe.text_encoder_2.add_adapter(text_encoder_config)

        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text_encoder_one")
        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text_encoder_two")

        inputs = self.get_dummy_inputs(torch_device)
        output_lora = pipe(**inputs).images
        self.assertTrue(
            not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
        )

        inputs = self.get_dummy_inputs(torch_device)
        output_lora_scale = pipe(**inputs, joint_attention_kwargs={"scale": 0.5}).images
        self.assertTrue(
            not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
            "Lora + scale should change the output",
        )

        inputs = self.get_dummy_inputs(torch_device)
        output_lora_0_scale = pipe(**inputs, joint_attention_kwargs={"scale": 0.0}).images
        self.assertTrue(
            np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
            "Lora + 0 scale should lead to same result as no LoRA",
        )

    def test_simple_inference_with_transformer_fused(self):
        components = self.get_dummy_components()
        transformer_lora_config = self.get_lora_config_for_transformer()
...
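
As the scale tests above suggest, the LoRA contribution can be modulated at inference time through `joint_attention_kwargs`. A rough usage sketch, assuming a pipeline that already has SD3 LoRA weights loaded (here the `trained-sd3-lora` checkpoint from the README example):

```python
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("trained-sd3-lora")

# "scale" is forwarded to the LoRA layers: 0.0 reproduces the base model,
# 1.0 applies the adapter fully, and intermediate values blend the two.
image = pipe(
    "a photo of TOK yarn art dog",
    joint_attention_kwargs={"scale": 0.5},
).images[0]
image.save("yarn_art_dog_half_scale.png")
```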