"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "a216b0bb7fbf713e348edb030e865c2703965bd2"
Unverified Commit 1bc6f3dc authored by Linoy Tsaban, committed by GitHub

[LoRA training] update metadata use for lora alpha + README (#11723)



* lora alpha

* Apply style fixes

* Update examples/advanced_diffusion_training/README_flux.md
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* fix readme format

---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 79bd7ecc
@@ -76,6 +76,24 @@ This command will prompt you for a token. Copy-paste yours from your [settings/t
> `pip install wandb`
> Alternatively, you can use other tools / train without reporting by modifying the flag `--report_to="wandb"`.
### LoRA Rank and Alpha
Two key LoRA hyperparameters are LoRA rank and LoRA alpha.
- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by `lora_alpha / rank`.
- `lora_alpha` vs. `rank`: this ratio dictates the LoRA's effective strength (see the short sketch below):
  - `lora_alpha == rank`: scaling factor is 1. The LoRA is applied with its learned strength. (e.g., `lora_alpha=16`, `rank=16`)
  - `lora_alpha < rank`: scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., `lora_alpha=8`, `rank=16`)
  - `lora_alpha > rank`: scaling factor > 1. Amplifies the LoRA's impact. Allows a lower-rank LoRA to have a stronger effect. (e.g., `lora_alpha=32`, `rank=16`)

> [!TIP]
> A common starting point is to set `lora_alpha` equal to `rank`.
> Some also set `lora_alpha` to twice the `rank` (e.g., `lora_alpha=32` for `rank=16`)
> to give the LoRA updates more influence without increasing the parameter count.
> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
> (e.g., `lora_alpha=8` for `rank=16`). Experimentation is often key to finding the optimal balance for your use case.
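To make the scaling concrete, here is a minimal sketch of how these two flags map onto the `peft` `LoraConfig` the training script builds. The values and the `target_modules` list are illustrative, not a prescription:

```python
from peft import LoraConfig

rank = 16
lora_alpha = 32  # the LoRA update is scaled by lora_alpha / rank = 2.0

# Roughly mirrors passing --rank=16 --lora_alpha=32 to the training script;
# the target_modules below are example attention projections.
transformer_lora_config = LoraConfig(
    r=rank,
    lora_alpha=lora_alpha,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)

print(lora_alpha / rank)  # effective scaling factor applied to the LoRA update
```

Halving `lora_alpha` to 8 in the same config would scale the update by 0.5 instead, matching the tip above.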
### Target Modules
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer (DiT). With this change, we may also want to explore
...
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
import sys
@@ -20,6 +21,8 @@ import tempfile
import safetensors
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
sys.path.append("..") sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402 from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
...@@ -281,3 +284,45 @@ class DreamBoothLoRAFluxAdvanced(ExamplesTestsAccelerate): ...@@ -281,3 +284,45 @@ class DreamBoothLoRAFluxAdvanced(ExamplesTestsAccelerate):
run_command(self._launch_args + resume_run_args) run_command(self._launch_args + resume_run_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"}) self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
    def test_dreambooth_lora_with_metadata(self):
        # Use a `lora_alpha` that is different from `rank`.
        lora_alpha = 8
        rank = 4
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --lora_alpha={lora_alpha}
                --rank={rank}
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)

            # save_pretrained smoke test
            state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
            self.assertTrue(os.path.isfile(state_dict_file))

            # Check if the metadata was properly serialized.
            with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
                metadata = f.metadata() or {}

            metadata.pop("format", None)
            raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
            if raw:
                raw = json.loads(raw)

            loaded_lora_alpha = raw["transformer.lora_alpha"]
            self.assertTrue(loaded_lora_alpha == lora_alpha)
            loaded_lora_rank = raw["transformer.r"]
            self.assertTrue(loaded_lora_rank == rank)
@@ -55,6 +55,7 @@ from diffusers import (
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
    _collate_lora_metadata,
    _set_state_dict_into_text_encoder,
    cast_training_params,
    compute_density_for_timestep_sampling,
@@ -431,6 +432,13 @@ def parse_args(input_args=None):
        help=("The dimension of the LoRA update matrices."),
    )
    parser.add_argument(
        "--lora_alpha",
        type=int,
        default=4,
        help="LoRA alpha to be used for additional scaling.",
    )
    parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
    parser.add_argument(
@@ -1556,7 +1564,7 @@ def main(args):
    # now we will add new LoRA weights to the attention layers
    transformer_lora_config = LoraConfig(
        r=args.rank,
-        lora_alpha=args.rank,
+        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        init_lora_weights="gaussian",
        target_modules=target_modules,
@@ -1565,7 +1573,7 @@ def main(args):
    if args.train_text_encoder:
        text_lora_config = LoraConfig(
            r=args.rank,
-            lora_alpha=args.rank,
+            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            init_lora_weights="gaussian",
            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
@@ -1582,13 +1590,15 @@ def main(args):
        if accelerator.is_main_process:
            transformer_lora_layers_to_save = None
            text_encoder_one_lora_layers_to_save = None
            modules_to_save = {}
            for model in models:
                if isinstance(model, type(unwrap_model(transformer))):
                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                    modules_to_save["transformer"] = model
                elif isinstance(model, type(unwrap_model(text_encoder_one))):
                    if args.train_text_encoder:  # when --train_text_encoder_ti we don't save the layers
                        text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
                        modules_to_save["text_encoder"] = model
                elif isinstance(model, type(unwrap_model(text_encoder_two))):
                    pass  # when --train_text_encoder_ti and --enable_t5_ti we don't save the layers
                else:
@@ -1601,6 +1611,7 @@ def main(args):
                output_dir,
                transformer_lora_layers=transformer_lora_layers_to_save,
                text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
                **_collate_lora_metadata(modules_to_save),
            )
            if args.train_text_encoder_ti:
                embedding_handler.save_embeddings(f"{args.output_dir}/{Path(args.output_dir).name}_emb.safetensors")
@@ -2359,16 +2370,19 @@ def main(args):
    # Save the lora layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        modules_to_save = {}
        transformer = unwrap_model(transformer)
        if args.upcast_before_saving:
            transformer.to(torch.float32)
        else:
            transformer = transformer.to(weight_dtype)
        transformer_lora_layers = get_peft_model_state_dict(transformer)
        modules_to_save["transformer"] = transformer
        if args.train_text_encoder:
            text_encoder_one = unwrap_model(text_encoder_one)
            text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
            modules_to_save["text_encoder"] = text_encoder_one
        else:
            text_encoder_lora_layers = None
@@ -2377,6 +2391,7 @@ def main(args):
            save_directory=args.output_dir,
            transformer_lora_layers=transformer_lora_layers,
            text_encoder_lora_layers=text_encoder_lora_layers,
            **_collate_lora_metadata(modules_to_save),
        )
        if args.train_text_encoder_ti:
...
@@ -170,6 +170,23 @@ accelerate launch train_dreambooth_lora_flux.py \
  --push_to_hub
```
### LoRA Rank and Alpha
Two key LoRA hyperparameters are LoRA rank and LoRA alpha.
- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by `lora_alpha / rank`.
- `lora_alpha` vs. `rank`: this ratio dictates the LoRA's effective strength:
  - `lora_alpha == rank`: scaling factor is 1. The LoRA is applied with its learned strength. (e.g., `lora_alpha=16`, `rank=16`)
  - `lora_alpha < rank`: scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., `lora_alpha=8`, `rank=16`)
  - `lora_alpha > rank`: scaling factor > 1. Amplifies the LoRA's impact. Allows a lower-rank LoRA to have a stronger effect. (e.g., `lora_alpha=32`, `rank=16`)

> [!TIP]
> A common starting point is to set `lora_alpha` equal to `rank`.
> Some also set `lora_alpha` to twice the `rank` (e.g., `lora_alpha=32` for `rank=16`)
> to give the LoRA updates more influence without increasing the parameter count.
> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
> (e.g., `lora_alpha=8` for `rank=16`). Experimentation is often key to finding the optimal balance for your use case.
> The `lora_alpha` you train with is also recorded in the saved checkpoint's metadata; see the snippet below for a quick way to check it.
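The `lora_alpha` (and `rank`) values are serialized into the metadata of the saved `pytorch_lora_weights.safetensors`. As a rough sketch (the output path is illustrative, and the key names follow the convention exercised by the test added in this change), you can confirm the stored values like so:

```python
import json

from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
from safetensors import safe_open

# Illustrative path: point this at the file written to your --output_dir.
state_dict_file = "trained-flux-lora/pytorch_lora_weights.safetensors"

with safe_open(state_dict_file, framework="pt", device="cpu") as f:
    metadata = f.metadata() or {}

# The adapter config is stored as a JSON string under a dedicated metadata key.
lora_config = json.loads(metadata.get(LORA_ADAPTER_METADATA_KEY, "{}"))
print(lora_config.get("transformer.lora_alpha"), lora_config.get("transformer.r"))
```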
### Target Modules
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer (DiT). With this change, we may also want to explore
...