Unverified commit a11b0f83, authored by Srimanth Agastyaraju and committed by GitHub

Fix: training resume from fp16 for SDXL Consistency Distillation (#6840)



* Fix: training resume from fp16 for lcm distill lora sdxl

* Fix coding quality - run linter

* Fix 1 - shift mixed precision cast before optimizer

* Fix 2 - State dict errors by removing load_lora_into_unet

* Update train_lcm_distill_lora_sdxl.py - Revert default cache dir to None

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 18355105
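For context, the resume path this PR fixes comes down to two steps: load the serialized LoRA weights back into the PEFT-adapted UNet, then upcast the trainable parameters to fp32 before training continues. The sketch below is illustrative only; the helper name `load_lora_checkpoint_into_unet` and the standalone structure are assumptions, not part of the script, but the calls (`lora_state_dict`, `convert_unet_state_dict_to_peft`, `set_peft_model_state_dict`, `cast_training_params`) are the same ones the diff introduces.

```python
# A minimal sketch of the resume path, assuming a PEFT-adapted `unet` and a
# checkpoint directory produced by the training script. The helper name and the
# standalone structure are illustrative, not part of the script itself.
import torch
from peft import set_peft_model_state_dict

from diffusers import StableDiffusionXLPipeline
from diffusers.training_utils import cast_training_params
from diffusers.utils import convert_unet_state_dict_to_peft


def load_lora_checkpoint_into_unet(unet, input_dir, mixed_precision="fp16"):
    # Read the serialized LoRA weights; the network alphas are not needed here.
    lora_state_dict, _ = StableDiffusionXLPipeline.lora_state_dict(input_dir)

    # Keep only the UNet entries and strip their "unet." prefix.
    unet_state_dict = {
        k.replace("unet.", ""): v for k, v in lora_state_dict.items() if k.startswith("unet.")
    }

    # Convert diffusers-style LoRA keys to PEFT naming and load them in place.
    unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
    set_peft_model_state_dict(unet, unet_state_dict, adapter_name="default")

    # The loaded LoRA weights come back in the checkpoint dtype (fp16 when the run
    # used fp16 mixed precision), so upcast the trainable params before resuming.
    if mixed_precision == "fp16":
        cast_training_params(unet, dtype=torch.float32)
    return unet
```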
@@ -36,7 +36,7 @@ from accelerate.utils import ProjectConfiguration, set_seed
 from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
-from peft import LoraConfig, get_peft_model_state_dict
+from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
 from torchvision import transforms
 from torchvision.transforms.functional import crop
 from tqdm.auto import tqdm
@@ -52,7 +52,12 @@ from diffusers import (
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import cast_training_params, resolve_interpolation_mode
-from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.utils import (
+    check_min_version,
+    convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
+    is_wandb_available,
+)
 from diffusers.utils.import_utils import is_xformers_available
@@ -858,11 +863,6 @@ def main(args):
     )
     unet.add_adapter(lora_config)

-    # Make sure the trainable params are in float32.
-    if args.mixed_precision == "fp16":
-        # only upcast trainable parameters (LoRA) into fp32
-        cast_training_params(unet, dtype=torch.float32)
-
     # Also move the alpha and sigma noise schedules to accelerator.device.
     alpha_schedule = alpha_schedule.to(accelerator.device)
     sigma_schedule = sigma_schedule.to(accelerator.device)
@@ -887,13 +887,31 @@
     def load_model_hook(models, input_dir):
         # load the LoRA into the model
         unet_ = accelerator.unwrap_model(unet)
-        lora_state_dict, network_alphas = StableDiffusionXLPipeline.lora_state_dict(input_dir)
-        StableDiffusionXLPipeline.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
+        lora_state_dict, _ = StableDiffusionXLPipeline.lora_state_dict(input_dir)
+        unet_state_dict = {
+            f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
+        }
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )

         for _ in range(len(models)):
             # pop models so that they are not loaded again
             models.pop()

+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+        if args.mixed_precision == "fp16":
+            cast_training_params(unet_, dtype=torch.float32)
+
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
@@ -1092,6 +1110,11 @@ def main(args):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )

+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(unet, dtype=torch.float32)
+
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
...
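The other half of the change (commit "Fix 1") is about ordering: under fp16 mixed precision, the trainable LoRA parameters should already be fp32 by the time the optimizer and AMP gradient scaler see them, otherwise unscaling fails with "Attempting to unscale FP16 gradients." The standalone sketch below is not taken from the script; the tensors and names are toy stand-ins that only illustrate the upcast-then-optimize order.

```python
# Toy illustration (not from the script) of the upcast-before-optimizer ordering:
# an fp16 trainable parameter, as it would end up after loading an fp16 checkpoint,
# is upcast to fp32 before the optimizer is built, which is what the moved
# cast_training_params call achieves in the training script.
import torch

fp16_param = torch.nn.Parameter(torch.zeros(4, 4, dtype=torch.float16))

# Upcast first, then construct the optimizer, so the optimizer state and any later
# gradient unscaling operate on fp32 tensors.
fp32_param = torch.nn.Parameter(fp16_param.data.float())
optimizer = torch.optim.AdamW([fp32_param], lr=1e-4)

assert all(
    p.dtype == torch.float32 for group in optimizer.param_groups for p in group["params"]
)
```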