Unverified commit 04ddad48 authored by Batuhan Taskaya, committed by GitHub

Add 'rank' parameter to Dreambooth LoRA training script (#3945)

parent 03d829d5
```diff
@@ -436,6 +436,12 @@ def parse_args(input_args=None):
         default=None,
         help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
     )
+    parser.add_argument(
+        "--rank",
+        type=int,
+        default=4,
+        help=("The dimension of the LoRA update matrices."),
+    )

     if input_args is not None:
         args = parser.parse_args(input_args)
```
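The flag defaults to 4, which matches the rank the LoRA attention processors already used, so omitting it should preserve the previous behavior. A minimal standalone sketch of how the flag parses (an isolated parser, not the full training script):

```python
import argparse

# Re-creation of just the new flag, for illustration.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--rank",
    type=int,
    default=4,
    help="The dimension of the LoRA update matrices.",
)

assert parser.parse_args([]).rank == 4                  # omitted: falls back to the default
assert parser.parse_args(["--rank", "16"]).rank == 16   # explicit override
```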
```diff
@@ -845,7 +851,9 @@ def main(args):
         lora_attn_processor_class = (
             LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
         )
         unet_lora_attn_procs[name] = lora_attn_processor_class(
-            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+            hidden_size=hidden_size,
+            cross_attention_dim=cross_attention_dim,
+            rank=args.rank,
         )
     unet.set_attn_processor(unet_lora_attn_procs)
```
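Inside the attention processors, `rank` is the inner dimension of the pair of low-rank matrices that LoRA trains while the original weights stay frozen. A generic sketch of that decomposition (an illustrative stand-in, not diffusers' `LoRAAttnProcessor`; the 320-wide layer is a hypothetical example):

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """y = W x + B(A x), with A of shape (rank, in) and B of shape (out, rank)."""

    def __init__(self, in_features: int, out_features: int, rank: int = 4):
        super().__init__()
        self.base = nn.Linear(in_features, out_features, bias=False)
        self.base.weight.requires_grad_(False)  # pretrained weight stays frozen
        self.lora_down = nn.Linear(in_features, rank, bias=False)  # A
        self.lora_up = nn.Linear(rank, out_features, bias=False)   # B
        nn.init.zeros_(self.lora_up.weight)  # update starts as a no-op

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.lora_up(self.lora_down(x))

layer = LoRALinear(320, 320, rank=8)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(trainable)  # 5120: two 320x8 matrices, vs. 102400 for the frozen base weight
```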
```diff
@@ -860,7 +868,9 @@ def main(args):
         for name, module in text_encoder.named_modules():
             if name.endswith(TEXT_ENCODER_ATTN_MODULE):
                 text_lora_attn_procs[name] = LoRAAttnProcessor(
-                    hidden_size=module.out_proj.out_features, cross_attention_dim=None
+                    hidden_size=module.out_proj.out_features,
+                    cross_attention_dim=None,
+                    rank=args.rank,
                 )
         text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
         temp_pipeline = DiffusionPipeline.from_pretrained(
```
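Both update matrices grow linearly with `rank`, so the trainable parameter count does too; that trade-off between adapter capacity and size is the knob this flag exposes for the UNet and (when trained) the text encoder alike. A quick back-of-the-envelope check, reusing the hypothetical 320-wide layer from the sketch above:

```python
def lora_params(in_features: int, out_features: int, rank: int) -> int:
    # A is (rank, in_features), B is (out_features, rank)
    return rank * in_features + out_features * rank

for r in (4, 8, 16):
    print(r, lora_params(320, 320, r))  # 4 -> 2560, 8 -> 5120, 16 -> 10240
```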