Unverified Commit cdf2ae8a authored by takuoko, committed by GitHub

[Enhance] Add LoRA rank args in train_text_to_image_lora (#3866)

* add a rank argument to the LoRA fine-tuning script

* remove network_alpha
parent 49949f32
@@ -343,6 +343,12 @@ def parse_args():
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
     parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+    parser.add_argument(
+        "--rank",
+        type=int,
+        default=4,
+        help=("The dimension of the LoRA update matrices."),
+    )
     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -464,7 +470,11 @@ def main():
             block_id = int(name[len("down_blocks.")])
             hidden_size = unet.config.block_out_channels[block_id]
-        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+        lora_attn_procs[name] = LoRAAttnProcessor(
+            hidden_size=hidden_size,
+            cross_attention_dim=cross_attention_dim,
+            rank=args.rank,
+        )
     unet.set_attn_processor(lora_attn_procs)
...
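For context, --rank sets the inner dimension of the two low-rank matrices that LoRA trains alongside each frozen attention projection, so it directly controls the number of trainable parameters per adapted layer. Below is a minimal, hypothetical PyTorch sketch of that decomposition; it is not code from this script, and the class and variable names are illustrative only.

import torch
import torch.nn as nn

class LoRALinearSketch(nn.Module):
    # Hypothetical illustration of a LoRA update: y = W x + up(down(x)),
    # where both factors share the inner dimension `rank`.
    def __init__(self, in_features: int, out_features: int, rank: int = 4):
        super().__init__()
        self.base = nn.Linear(in_features, out_features, bias=False)  # frozen weight W
        self.base.weight.requires_grad_(False)
        self.down = nn.Linear(in_features, rank, bias=False)  # weight shape (rank, in_features)
        self.up = nn.Linear(rank, out_features, bias=False)   # weight shape (out_features, rank)
        nn.init.zeros_(self.up.weight)  # the update starts as a no-op

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.up(self.down(x))

# A larger --rank means larger update matrices and more trainable parameters.
layer = LoRALinearSketch(in_features=320, out_features=320, rank=8)
print(sum(p.numel() for p in layer.parameters() if p.requires_grad))  # 5120 = 2 * 320 * 8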