Unverified commit 09b8aebd authored by Sayak Paul, committed by GitHub

[training] fixes to the quantization training script and add AdEMAMix optimizer as an option (#9806)

* fixes

* more fixes.
parent c1d4a0dd
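
For context on the new option: AdEMAMix keeps two exponential moving averages of the gradient (one fast, one slow) and mixes them, and bitsandbytes ships both a full-precision and an 8-bit variant. A minimal standalone sketch of what the `--optimizer AdEMAMix` path selects, assuming a bitsandbytes release recent enough to include these optimizers; the toy model and shapes are illustrative only:

import bitsandbytes as bnb
import torch

model = torch.nn.Linear(16, 16).cuda()  # stand-in for the Flux transformer

# --use_8bit_ademamix would select bnb.optim.AdEMAMix8bit instead
optimizer = bnb.optim.AdEMAMix(model.parameters())

loss = model(torch.randn(4, 16, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
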
@@ -349,7 +349,7 @@ def parse_args(input_args=None):
"--optimizer", "--optimizer",
type=str, type=str,
default="AdamW", default="AdamW",
help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), choices=["AdamW", "Prodigy", "AdEMAMix"],
) )
     parser.add_argument(
@@ -357,6 +357,11 @@ def parse_args(input_args=None):
action="store_true", action="store_true",
help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
) )
+    parser.add_argument(
+        "--use_8bit_ademamix",
+        action="store_true",
+        help="Whether or not to use 8-bit AdEMAMix from bitsandbytes.",
+    )
     parser.add_argument(
         "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
@@ -820,16 +825,15 @@ def main(args):
     params_to_optimize = [transformer_parameters_with_lr]
     # Optimizer creation
-    if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
-        logger.warning(
-            f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
-            "Defaulting to adamW"
-        )
-        args.optimizer = "adamw"
     if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
         logger.warning(
             f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
             f"set to {args.optimizer.lower()}"
         )
+    if args.use_8bit_ademamix and not args.optimizer.lower() == "ademamix":
+        logger.warning(
+            f"use_8bit_ademamix is ignored when optimizer is not set to 'AdEMAMix'. Optimizer was "
+            f"set to {args.optimizer.lower()}"
+        )
@@ -853,6 +857,20 @@ def main(args):
             eps=args.adam_epsilon,
         )
+    elif args.optimizer.lower() == "ademamix":
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError(
+                "To use AdEMAMix (or its 8bit variant), please install the bitsandbytes library: `pip install -U bitsandbytes`."
+            )
+
+        if args.use_8bit_ademamix:
+            optimizer_class = bnb.optim.AdEMAMix8bit
+        else:
+            optimizer_class = bnb.optim.AdEMAMix
+
+        optimizer = optimizer_class(params_to_optimize)
     if args.optimizer.lower() == "prodigy":
         try:
             import prodigyopt
@@ -868,7 +886,6 @@ def main(args):
         optimizer = optimizer_class(
             params_to_optimize,
-            lr=args.learning_rate,
             betas=(args.adam_beta1, args.adam_beta2),
             beta3=args.prodigy_beta3,
             weight_decay=args.adam_weight_decay,
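
Note that the new AdEMAMix branch passes only params_to_optimize and leaves every hyperparameter at the bitsandbytes defaults. Because params_to_optimize is a list of param-group dicts carrying a per-group "lr" (see the context line earlier in this commit), the learning rate still comes from --learning_rate; per-group keys override an optimizer's default. A sketch under those assumptions, with a toy module standing in for the transformer:

import bitsandbytes as bnb
import torch

transformer = torch.nn.Linear(8, 8).cuda()  # stand-in module
params_to_optimize = [{"params": transformer.parameters(), "lr": 1e-4}]

optimizer = bnb.optim.AdEMAMix(params_to_optimize)
assert optimizer.param_groups[0]["lr"] == 1e-4  # group "lr" wins over the default
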
@@ -1020,12 +1037,12 @@ def main(args):
                 model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)
-                vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
+                vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)
                 latent_image_ids = FluxPipeline._prepare_latent_image_ids(
                     model_input.shape[0],
-                    model_input.shape[2],
-                    model_input.shape[3],
+                    model_input.shape[2] // 2,
+                    model_input.shape[3] // 2,
                     accelerator.device,
                     weight_dtype,
                 )
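
Two related fixes in this hunk: the VAE only downsamples between resolution levels, so the spatial compression is 2 ** (levels - 1) rather than 2 ** levels, and Flux packs latents into 2x2 patches before the transformer, so the latent-image-ID grid uses half the latent height/width. A worked check, with block_out_channels assumed from the standard Flux VAE config:

vae_config_block_out_channels = [128, 256, 512, 512]  # assumed Flux VAE config

# downsampling happens between the four levels, so compression is 2**3 = 8
vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)
assert vae_scale_factor == 8  # a 512x512 image encodes to 64x64 latents

# the 2x2 latent packing halves each spatial dim for the ID grid
latent_height, latent_width = 64, 64
assert (latent_height // 2, latent_width // 2) == (32, 32)
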
@@ -1059,7 +1076,7 @@
                 )
                 # handle guidance
-                if transformer.config.guidance_embeds:
+                if unwrap_model(transformer).config.guidance_embeds:
                     guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
                     guidance = guidance.expand(model_input.shape[0])
                 else:
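
The unwrap_model(transformer) change matters once accelerate has prepared the model: the wrapper it adds (e.g. DistributedDataParallel) does not expose the diffusers .config attribute. A toy reproduction, with a hypothetical Wrapper standing in for the real wrapper and a simplified unwrap_model:

import torch

class Wrapper(torch.nn.Module):
    # hypothetical stand-in for the DDP/compile wrapper accelerate may add
    def __init__(self, module):
        super().__init__()
        self.module = module

def unwrap_model(model):
    # simplified version of the helper the training script uses
    return model.module if hasattr(model, "module") else model

base = torch.nn.Linear(4, 4)
base.config = {"guidance_embeds": True}  # stand-in for the diffusers config
wrapped = Wrapper(base)

assert unwrap_model(wrapped).config["guidance_embeds"]
# wrapped.config would raise AttributeError, which is the bug this hunk fixes
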
@@ -1082,8 +1099,8 @@
                 )[0]
                 model_pred = FluxPipeline._unpack_latents(
                     model_pred,
-                    height=int(model_input.shape[2] * vae_scale_factor / 2),
-                    width=int(model_input.shape[3] * vae_scale_factor / 2),
+                    height=model_input.shape[2] * vae_scale_factor,
+                    width=model_input.shape[3] * vae_scale_factor,
                     vae_scale_factor=vae_scale_factor,
                 )
...