"app/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "45bf83ff58b8715a4d2bf6519e724db6ac0fc7e4"
Unverified commit 04ddad48, authored by Batuhan Taskaya, committed by GitHub

Add 'rank' parameter to Dreambooth LoRA training script (#3945)

parent 03d829d5
The change adds a `--rank` CLI flag in `parse_args` and threads it through to both LoRA attention-processor call sites:

```diff
@@ -436,6 +436,12 @@ def parse_args(input_args=None):
         default=None,
         help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
     )
+    parser.add_argument(
+        "--rank",
+        type=int,
+        default=4,
+        help=("The dimension of the LoRA update matrices."),
+    )
     if input_args is not None:
         args = parser.parse_args(input_args)
```
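The flag plugs into the script's existing argparse setup, and the default of 4 matches the rank the processors previously used implicitly. A minimal standalone sketch of how the parsed value behaves (the bare parser here is illustrative, not the script's full parser):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--rank",
    type=int,
    default=4,
    help="The dimension of the LoRA update matrices.",
)

print(parser.parse_args([]).rank)               # 4 -- same as the old hard-coded behavior
print(parser.parse_args(["--rank", "8"]).rank)  # 8 -- overridden from the CLI
```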
```diff
@@ -845,7 +851,9 @@ def main(args):
             LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
         )
         unet_lora_attn_procs[name] = lora_attn_processor_class(
-            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+            hidden_size=hidden_size,
+            cross_attention_dim=cross_attention_dim,
+            rank=args.rank,
         )

     unet.set_attn_processor(unet_lora_attn_procs)
```
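For context on what the new argument controls: a LoRA attention processor learns a low-rank additive update to each frozen projection, and `rank` is the inner dimension of that factorization. A minimal sketch of the idea in plain PyTorch (an illustration of the technique, not diffusers' actual `LoRAAttnProcessor` implementation):

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Computes W x + B(A x): A maps down to `rank` dims, B maps back up,
    so the adapter's capacity and parameter count scale with the rank."""

    def __init__(self, in_features: int, out_features: int, rank: int = 4):
        super().__init__()
        self.base = nn.Linear(in_features, out_features, bias=False)
        self.base.weight.requires_grad_(False)  # frozen pretrained weight
        self.lora_down = nn.Linear(in_features, rank, bias=False)  # "A" matrix
        self.lora_up = nn.Linear(rank, out_features, bias=False)   # "B" matrix
        nn.init.normal_(self.lora_down.weight, std=1.0 / rank)
        nn.init.zeros_(self.lora_up.weight)  # update starts as a no-op

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.lora_up(self.lora_down(x))

layer = LoRALinear(768, 768, rank=8)
out = layer(torch.randn(1, 77, 768))  # output shape unchanged; only 2 * 768 * 8 params train
```

A larger rank buys the adapter more capacity at the cost of more trainable parameters; the script's default of 4 keeps checkpoints small.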
```diff
@@ -860,7 +868,9 @@ def main(args):
         for name, module in text_encoder.named_modules():
             if name.endswith(TEXT_ENCODER_ATTN_MODULE):
                 text_lora_attn_procs[name] = LoRAAttnProcessor(
-                    hidden_size=module.out_proj.out_features, cross_attention_dim=None
+                    hidden_size=module.out_proj.out_features,
+                    cross_attention_dim=None,
+                    rank=args.rank,
                 )
         text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
         temp_pipeline = DiffusionPipeline.from_pretrained(
```
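Both call sites read the same `args.rank`, so the UNet and (optionally trained) text-encoder adapters share a single rank; from the command line this surfaces as e.g. `--rank 8` when launching `train_dreambooth_lora.py`. Assuming a diffusers version contemporary with this commit, the new keyword can be exercised directly; the sizes below are illustrative:

```python
from diffusers.models.attention_processor import LoRAAttnProcessor

# hidden_size / cross_attention_dim come from the model config;
# rank now comes from the new CLI flag instead of the silent default of 4.
proc = LoRAAttnProcessor(hidden_size=320, cross_attention_dim=768, rank=8)
```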