Unverified Commit 8ce37ab0, authored by Sayak Paul and committed by GitHub

[training] use the lr when using 8bit adam. (#9796)

* use the lr when using 8bit adam.

* remove lr as we pack it in params_to_optimize.

---------
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
parent 09b8aebd
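
The change relies on PyTorch's per-parameter-group options: once every dict in params_to_optimize carries its own "lr" key, the optimizer reads the learning rate from each group and the top-level lr= argument only serves as a default for groups that omit the key, so it can be dropped. Below is a minimal sketch of that behaviour, not the training scripts themselves: the two Linear modules and the learning-rate values are placeholders, and torch.optim.AdamW stands in for bnb.optim.AdamW8bit, which accepts the same group dictionaries.

import torch

# Stand-ins for the LoRA parameter sets the training scripts build.
transformer = torch.nn.Linear(4, 4)
text_encoder = torch.nn.Linear(4, 4)

transformer_parameters_with_lr = {"params": transformer.parameters(), "lr": 1e-4}
text_parameters_one_with_lr = {"params": text_encoder.parameters(), "lr": 5e-5}
params_to_optimize = [transformer_parameters_with_lr, text_parameters_one_with_lr]

# No top-level lr= is passed: each group already defines its own.
optimizer = torch.optim.AdamW(
    params_to_optimize,
    betas=(0.9, 0.999),
    weight_decay=1e-2,
)

for group in optimizer.param_groups:
    print(group["lr"])  # 0.0001, then 5e-05 -- taken from the groups, not a global default
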
@@ -1778,15 +1778,10 @@ def main(args):
         if not args.enable_t5_ti:
             # pure textual inversion - only clip
             if pure_textual_inversion:
-                params_to_optimize = [
-                    text_parameters_one_with_lr,
-                ]
+                params_to_optimize = [text_parameters_one_with_lr]
                 te_idx = 0
             else:  # regular te training or regular pivotal for clip
-                params_to_optimize = [
-                    transformer_parameters_with_lr,
-                    text_parameters_one_with_lr,
-                ]
+                params_to_optimize = [transformer_parameters_with_lr, text_parameters_one_with_lr]
                 te_idx = 1
         elif args.enable_t5_ti:
             # pivotal tuning of clip & t5
@@ -1809,9 +1804,7 @@ def main(args):
             ]
             te_idx = 1
     else:
-        params_to_optimize = [
-            transformer_parameters_with_lr,
-        ]
+        params_to_optimize = [transformer_parameters_with_lr]

     # Optimizer creation
     if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
@@ -1871,7 +1864,6 @@ def main(args):
             params_to_optimize[-1]["lr"] = args.learning_rate
         optimizer = optimizer_class(
             params_to_optimize,
-            lr=args.learning_rate,
             betas=(args.adam_beta1, args.adam_beta2),
             beta3=args.prodigy_beta3,
             weight_decay=args.adam_weight_decay,
...
@@ -1358,10 +1358,7 @@ def main(args):
             else args.adam_weight_decay,
             "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
         }
-        params_to_optimize = [
-            unet_lora_parameters_with_lr,
-            text_lora_parameters_one_with_lr,
-        ]
+        params_to_optimize = [unet_lora_parameters_with_lr, text_lora_parameters_one_with_lr]
     else:
         params_to_optimize = [unet_lora_parameters_with_lr]

@@ -1423,7 +1420,6 @@ def main(args):

         optimizer = optimizer_class(
             params_to_optimize,
-            lr=args.learning_rate,
             betas=(args.adam_beta1, args.adam_beta2),
             beta3=args.prodigy_beta3,
             weight_decay=args.adam_weight_decay,
...
@@ -1794,7 +1794,6 @@ def main(args):

         optimizer = optimizer_class(
             params_to_optimize,
-            lr=args.learning_rate,
             betas=(args.adam_beta1, args.adam_beta2),
             beta3=args.prodigy_beta3,
             weight_decay=args.adam_weight_decay,
...
@@ -947,7 +947,6 @@ def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):

         optimizer = optimizer_class(
             params_to_optimize,
-            lr=args.learning_rate,
             betas=(args.adam_beta1, args.adam_beta2),
             beta3=args.prodigy_beta3,
             weight_decay=args.adam_weight_decay,
...
@@ -969,7 +969,6 @@ def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):

         optimizer = optimizer_class(
             params_to_optimize,
-            lr=args.learning_rate,
             betas=(args.adam_beta1, args.adam_beta2),
             beta3=args.prodigy_beta3,
             weight_decay=args.adam_weight_decay,
...
@@ -1226,10 +1226,7 @@ def main(args):
             "weight_decay": args.adam_weight_decay_text_encoder,
             "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
         }
-        params_to_optimize = [
-            transformer_parameters_with_lr,
-            text_parameters_one_with_lr,
-        ]
+        params_to_optimize = [transformer_parameters_with_lr, text_parameters_one_with_lr]
     else:
         params_to_optimize = [transformer_parameters_with_lr]

@@ -1291,7 +1288,6 @@ def main(args):

         optimizer = optimizer_class(
             params_to_optimize,
-            lr=args.learning_rate,
             betas=(args.adam_beta1, args.adam_beta2),
             beta3=args.prodigy_beta3,
             weight_decay=args.adam_weight_decay,
...
@@ -1335,10 +1335,7 @@ def main(args):
             "weight_decay": args.adam_weight_decay_text_encoder,
             "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
         }
-        params_to_optimize = [
-            transformer_parameters_with_lr,
-            text_parameters_one_with_lr,
-        ]
+        params_to_optimize = [transformer_parameters_with_lr, text_parameters_one_with_lr]
     else:
         params_to_optimize = [transformer_parameters_with_lr]

@@ -1400,7 +1397,6 @@ def main(args):

         optimizer = optimizer_class(
             params_to_optimize,
-            lr=args.learning_rate,
             betas=(args.adam_beta1, args.adam_beta2),
             beta3=args.prodigy_beta3,
             weight_decay=args.adam_weight_decay,
...
@@ -1468,7 +1468,6 @@ def main(args):

         optimizer = optimizer_class(
             params_to_optimize,
-            lr=args.learning_rate,
             betas=(args.adam_beta1, args.adam_beta2),
             beta3=args.prodigy_beta3,
             weight_decay=args.adam_weight_decay,
...
@@ -1402,7 +1402,6 @@ def main(args):

         optimizer = optimizer_class(
             params_to_optimize,
-            lr=args.learning_rate,
             betas=(args.adam_beta1, args.adam_beta2),
             beta3=args.prodigy_beta3,
             weight_decay=args.adam_weight_decay,
...
@@ -1328,7 +1328,6 @@ def main(args):

         optimizer = optimizer_class(
             params_to_optimize,
-            lr=args.learning_rate,
             betas=(args.adam_beta1, args.adam_beta2),
             beta3=args.prodigy_beta3,
             weight_decay=args.adam_weight_decay,
...
@@ -1475,7 +1475,6 @@ def main(args):

         optimizer = optimizer_class(
             params_to_optimize,
-            lr=args.learning_rate,
             betas=(args.adam_beta1, args.adam_beta2),
             beta3=args.prodigy_beta3,
             weight_decay=args.adam_weight_decay,
...