"docs/vscode:/vscode.git/clone" did not exist on "1bdda8cb6e111903aa29d75dc2f33498f5df533a"
Unverified Commit d720b213 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Advanced LoRA v1.5] fix: gradient unscaling problem (#7018)



fix: gradient unscaling problem
Co-authored-by: default avatarLinoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
parent 9cc96a64
...@@ -39,7 +39,7 @@ from accelerate.logging import get_logger ...@@ -39,7 +39,7 @@ from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder from huggingface_hub import create_repo, upload_folder
from packaging import version from packaging import version
from peft import LoraConfig from peft import LoraConfig, set_peft_model_state_dict
from peft.utils import get_peft_model_state_dict from peft.utils import get_peft_model_state_dict
from PIL import Image from PIL import Image
from PIL.ImageOps import exif_transpose from PIL.ImageOps import exif_transpose
...@@ -59,12 +59,13 @@ from diffusers import ( ...@@ -59,12 +59,13 @@ from diffusers import (
) )
from diffusers.loaders import StableDiffusionLoraLoaderMixin from diffusers.loaders import StableDiffusionLoraLoaderMixin
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
from diffusers.utils import ( from diffusers.utils import (
check_min_version, check_min_version,
convert_all_state_dict_to_peft, convert_all_state_dict_to_peft,
convert_state_dict_to_diffusers, convert_state_dict_to_diffusers,
convert_state_dict_to_kohya, convert_state_dict_to_kohya,
convert_unet_state_dict_to_peft,
is_wandb_available, is_wandb_available,
) )
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
...@@ -1319,6 +1320,37 @@ def main(args): ...@@ -1319,6 +1320,37 @@ def main(args):
else: else:
raise ValueError(f"unexpected save model: {model.__class__}") raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
if args.train_text_encoder:
# Do we need to call `scale_lora_layers()` here?
_set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
_set_state_dict_into_text_encoder(
lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_one_
)
# Make sure the trainable params are in float32. This is again needed since the base models
# are in `weight_dtype`. More details:
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
if args.mixed_precision == "fp16":
models = [unet_]
if args.train_text_encoder:
models.extend([text_encoder_one_])
# only upcast trainable parameters (LoRA) into fp32
cast_training_params(models)
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
StableDiffusionLoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_) StableDiffusionLoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment