Unverified Commit 6427aa99 authored by takuoko, committed by GitHub

[Enhance] Add rank in dreambooth (#4112)

Add a `rank` argument to the DreamBooth LoRA training script so the rank of the LoRA update matrices is configurable for both the UNet attention processors and the text encoder.
parent 8b18cd8e
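The hunks below read the new value from `args.rank`; the flag definition itself falls outside this excerpt. A minimal sketch of how such a flag is presumably declared on the script's argument parser (the flag name follows the diff, but the default and help text here are assumptions):

```python
import argparse

parser = argparse.ArgumentParser(description="DreamBooth LoRA training script")
# Assumed flag definition; the actual declaration is not shown in this diff.
parser.add_argument(
    "--rank",
    type=int,
    default=4,  # assumed default, matching diffusers' LoRA attention processors
    help="The dimension of the LoRA update matrices.",
)
args = parser.parse_args()
```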
@@ -872,7 +872,9 @@ def main(args):
             LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
         )
-        module = lora_attn_processor_class(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+        module = lora_attn_processor_class(
+            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.rank
+        )
         unet_lora_attn_procs[name] = module
         unet_lora_parameters.extend(module.parameters())
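Here `rank` sets the inner dimension of the low-rank update the processor trains on top of the frozen attention projections, so a higher rank means more trainable parameters. A minimal sketch of the generic LoRA pattern (illustrative only, not the actual `LoRAAttnProcessor` code):

```python
import torch
import torch.nn as nn

class LoRALinearSketch(nn.Module):
    """Generic LoRA layer: a trainable update up(down(x)) whose inner
    dimension is `rank`; it is added to a frozen base projection elsewhere."""

    def __init__(self, in_features: int, out_features: int, rank: int = 4):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)  # the update starts as a no-op

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.up(self.down(hidden_states))
```

For a 320-dimensional projection, `rank=4` trains 2 x 4 x 320 = 2,560 weights per projection instead of the full 320 x 320 = 102,400.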
@@ -882,7 +884,7 @@ def main(args):
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
         # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-        text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32)
+        text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32, rank=args.rank)

     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
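As the comment notes, the text encoder is not addressed through attention processors, so its LoRA layers are injected by monkey-patching the forward calls of its attention blocks. A rough sketch of that pattern (illustrative only; the real logic lives in `LoraLoaderMixin._modify_text_encoder`, and `patch_linear_with_lora` is a hypothetical helper):

```python
import torch
import torch.nn as nn

def patch_linear_with_lora(linear: nn.Linear, rank: int = 4):
    """Wrap a frozen nn.Linear so its output gains a trainable rank-`rank`
    LoRA update. Hypothetical sketch, not the diffusers implementation."""
    lora = nn.Sequential(
        nn.Linear(linear.in_features, rank, bias=False),
        nn.Linear(rank, linear.out_features, bias=False),
    ).to(torch.float32)  # keep trainable weights in fp32, per the comment in the diff
    nn.init.zeros_(lora[1].weight)  # the update starts as a no-op

    original_forward = linear.forward

    def forward_with_lora(hidden_states):
        return original_forward(hidden_states) + lora(hidden_states.to(torch.float32))

    linear.forward = forward_with_lora
    return list(lora.parameters())  # the new trainable parameters
```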
@@ -1364,7 +1366,7 @@ def main(args):
         pipeline = pipeline.to(accelerator.device)

         # load attention processors
-        pipeline.load_lora_weights(args.output_dir)
+        pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.bin")

         # run inference
         images = []
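On the loading side, passing `weight_name` pins `load_lora_weights` to the serialized `pytorch_lora_weights.bin` file instead of relying on the loader's default lookup. A brief inference sketch (the base model ID, output path, and prompt are placeholders):

```python
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # placeholder base model
    torch_dtype=torch.float16,
)
pipeline.load_lora_weights("path/to/output_dir", weight_name="pytorch_lora_weights.bin")
pipeline = pipeline.to("cuda")
image = pipeline("a photo of sks dog in a bucket").images[0]
image.save("lora_sample.png")
```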