Unverified Commit a080f0d3 authored by Sayak Paul, committed by GitHub

[Training Utils] create a utility for casting the lora params during training. (#6553)

parent 79df5038
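In short, each training script in this diff drops its hand-rolled requires_grad loop in favor of a single call to the new helper. Below is a minimal sketch of the resulting call pattern, reusing the variable names from the SDXL scripts touched here; the surrounding argument parsing, model loading, and LoRA adapter setup are assumed to already be in place.

import torch
from diffusers.training_utils import cast_training_params

if args.mixed_precision == "fp16":
    models = [unet]
    if args.train_text_encoder:
        models.extend([text_encoder_one, text_encoder_two])
    # only upcast trainable parameters (LoRA) into fp32; frozen weights stay in fp16
    cast_training_params(models, dtype=torch.float32)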
@@ -51,7 +51,7 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import resolve_interpolation_mode
+from diffusers.training_utils import cast_training_params, resolve_interpolation_mode
 from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
@@ -860,10 +860,8 @@ def main(args):
     # Make sure the trainable params are in float32.
     if args.mixed_precision == "fp16":
-        for param in unet.parameters():
-            # only upcast trainable parameters (LoRA) into fp32
-            if param.requires_grad:
-                param.data = param.to(torch.float32)
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(unet, dtype=torch.float32)
 
     # Also move the alpha and sigma noise schedules to accelerator.device.
     alpha_schedule = alpha_schedule.to(accelerator.device)
...
@@ -53,7 +53,7 @@ from diffusers import (
 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import _set_state_dict_into_text_encoder, compute_snr
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
 from diffusers.utils import (
     check_min_version,
     convert_state_dict_to_diffusers,
@@ -1086,11 +1086,8 @@ def main(args):
             models = [unet_]
             if args.train_text_encoder:
                 models.extend([text_encoder_one_, text_encoder_two_])
-            for model in models:
-                for param in model.parameters():
-                    # only upcast trainable parameters (LoRA) into fp32
-                    if param.requires_grad:
-                        param.data = param.to(torch.float32)
+            # only upcast trainable parameters (LoRA) into fp32
+            cast_training_params(models)
 
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
@@ -1110,11 +1107,9 @@ def main(args):
         models = [unet]
         if args.train_text_encoder:
             models.extend([text_encoder_one, text_encoder_two])
-        for model in models:
-            for param in model.parameters():
-                # only upcast trainable parameters (LoRA) into fp32
-                if param.requires_grad:
-                    param.data = param.to(torch.float32)
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(models, dtype=torch.float32)
 
     unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
...
@@ -43,7 +43,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
 import diffusers
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr
+from diffusers.training_utils import cast_training_params, compute_snr
 from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
@@ -466,10 +466,8 @@ def main():
     # Add adapter and make sure the trainable params are in float32.
     unet.add_adapter(unet_lora_config)
     if args.mixed_precision == "fp16":
-        for param in unet.parameters():
-            # only upcast trainable parameters (LoRA) into fp32
-            if param.requires_grad:
-                param.data = param.to(torch.float32)
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(unet, dtype=torch.float32)
 
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
...
@@ -51,7 +51,7 @@ from diffusers import (
 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr
+from diffusers.training_utils import cast_training_params, compute_snr
 from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
@@ -634,11 +634,8 @@ def main(args):
         models = [unet]
         if args.train_text_encoder:
             models.extend([text_encoder_one, text_encoder_two])
-        for model in models:
-            for param in model.parameters():
-                # only upcast trainable parameters (LoRA) into fp32
-                if param.requires_grad:
-                    param.data = param.to(torch.float32)
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(models, dtype=torch.float32)
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
...
 import contextlib
 import copy
 import random
-from typing import Any, Dict, Iterable, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Union
 
 import numpy as np
 import torch
@@ -121,6 +121,16 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
     return lora_state_dict
 
 
+def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
+    if not isinstance(model, list):
+        model = [model]
+    for m in model:
+        for param in m.parameters():
+            # only upcast trainable parameters into fp32
+            if param.requires_grad:
+                param.data = param.to(dtype)
+
+
 def _set_state_dict_into_text_encoder(
     lora_state_dict: Dict[str, torch.Tensor], prefix: str, text_encoder: torch.nn.Module
 ):
...
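For reference, here is a self-contained sketch of how the helper behaves. The ToyLoraModel below is a made-up stand-in for a model with frozen base weights and trainable LoRA adapters; only the call to cast_training_params comes from this PR.

import torch
from diffusers.training_utils import cast_training_params

class ToyLoraModel(torch.nn.Module):
    # hypothetical module: a frozen "base" layer plus trainable low-rank adapters
    def __init__(self):
        super().__init__()
        self.base = torch.nn.Linear(16, 16)
        self.lora_down = torch.nn.Linear(16, 4, bias=False)
        self.lora_up = torch.nn.Linear(4, 16, bias=False)
        self.base.requires_grad_(False)  # freeze the base weight

model = ToyLoraModel().to(dtype=torch.float16)  # whole model starts out in fp16

# upcast only the trainable (requires_grad) parameters back to fp32;
# the frozen base weight keeps its fp16 dtype
cast_training_params(model, dtype=torch.float32)

assert model.base.weight.dtype == torch.float16
assert model.lora_up.weight.dtype == torch.float32

# a list of models is accepted too, e.g. cast_training_params([unet, text_encoder])

Accepting either a single module or a list keeps the call sites in the training scripts to one line, whether only the UNet or also the text encoders carry trainable LoRA parameters.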