Unverified commit 4188f306, authored by Levi McCallum, committed by GitHub

Add rank argument to train_dreambooth_lora_sdxl.py (#4343)

* Add rank argument to train_dreambooth_lora_sdxl.py

* Update train_dreambooth_lora_sdxl.py
parent 0b4430e8
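
The new flag exposes the inner dimension of the low-rank matrices that LoRA trains on top of each frozen attention projection, so a larger `--rank` means larger update matrices and more trainable parameters. A minimal, self-contained sketch of what the value controls (illustrative only; the `down`/`up` names and the example dimensions are assumptions, not the actual diffusers `LoRAAttnProcessor` internals):

import torch

# Sketch of the LoRA low-rank update for one attention projection weight.
hidden_size, cross_attention_dim, rank = 1280, 2048, 4

W = torch.randn(hidden_size, cross_attention_dim)      # frozen base weight
down = torch.randn(rank, cross_attention_dim) * 0.01   # trainable, (rank, in_features)
up = torch.zeros(hidden_size, rank)                     # trainable, (out_features, rank), zero-init

delta_W = up @ down                       # low-rank update, same shape as W
W_adapted = W + delta_W                   # weight effectively used during training/inference
trainable = down.numel() + up.numel()     # = rank * (cross_attention_dim + hidden_size)
print(W_adapted.shape, trainable)

The trainable parameter count per adapted projection grows linearly with the rank.
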
@@ -402,6 +402,12 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
+    parser.add_argument(
+        "--rank",
+        type=int,
+        default=4,
+        help=("The dimension of the LoRA update matrices."),
+    )

     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -767,7 +773,9 @@ def main(args):
         lora_attn_processor_class = (
             LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
         )
-        module = lora_attn_processor_class(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+        module = lora_attn_processor_class(
+            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.rank
+        )
         unet_lora_attn_procs[name] = module
         unet_lora_parameters.extend(module.parameters())
@@ -777,8 +785,12 @@ def main(args):
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
         # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-        text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(text_encoder_one, dtype=torch.float32)
-        text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(text_encoder_two, dtype=torch.float32)
+        text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
+            text_encoder_one, dtype=torch.float32, rank=args.rank
+        )
+        text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
+            text_encoder_two, dtype=torch.float32, rank=args.rank
+        )

     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
...
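
To sanity-check the new argument without running the full script, one can mirror just this flag in a throwaway parser (a sketch; the real `parse_args` defines many more options):

import argparse

# Only the argument added in this commit, to show the default and an override.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--rank",
    type=int,
    default=4,
    help="The dimension of the LoRA update matrices.",
)

print(parser.parse_args([]).rank)               # 4 (default)
print(parser.parse_args(["--rank", "8"]).rank)  # 8

In the script itself, the parsed value then flows into `lora_attn_processor_class(..., rank=args.rank)` and `LoraLoaderMixin._modify_text_encoder(..., rank=args.rank)`, as shown in the hunks above.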