"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "3fa2414866f69a5ba541bb27a94963df01d9bd06"
Unverified Commit 1ac07d8a authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Training examples] Follow up of #6306 (#6346)

* add to dreambooth lora.

* add: t2i lora.

* add: sdxl t2i lora.

* style

* lcm lora sdxl.

* unwrap

* fix: enable_adapters().
parent 1fff5277
...@@ -51,7 +51,7 @@ from diffusers import ( ...@@ -51,7 +51,7 @@ from diffusers import (
UNet2DConditionModel, UNet2DConditionModel,
) )
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
...@@ -113,7 +113,7 @@ def log_validation(vae, args, accelerator, weight_dtype, step, unet=None, is_fin ...@@ -113,7 +113,7 @@ def log_validation(vae, args, accelerator, weight_dtype, step, unet=None, is_fin
if unet is None: if unet is None:
raise ValueError("Must provide a `unet` when doing intermediate validation.") raise ValueError("Must provide a `unet` when doing intermediate validation.")
unet = accelerator.unwrap_model(unet) unet = accelerator.unwrap_model(unet)
state_dict = get_peft_model_state_dict(unet) state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
to_load = state_dict to_load = state_dict
else: else:
to_load = args.output_dir to_load = args.output_dir
...@@ -819,7 +819,7 @@ def main(args): ...@@ -819,7 +819,7 @@ def main(args):
unet_ = accelerator.unwrap_model(unet) unet_ = accelerator.unwrap_model(unet)
# also save the checkpoints in native `diffusers` format so that it can be easily # also save the checkpoints in native `diffusers` format so that it can be easily
# be independently loaded via `load_lora_weights()`. # be independently loaded via `load_lora_weights()`.
state_dict = get_peft_model_state_dict(unet_) state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet_))
StableDiffusionXLPipeline.save_lora_weights(output_dir, unet_lora_layers=state_dict) StableDiffusionXLPipeline.save_lora_weights(output_dir, unet_lora_layers=state_dict)
for _, model in enumerate(models): for _, model in enumerate(models):
...@@ -1184,7 +1184,7 @@ def main(args): ...@@ -1184,7 +1184,7 @@ def main(args):
# solver timestep. # solver timestep.
# With the adapters disabled, the `unet` is the regular teacher model. # With the adapters disabled, the `unet` is the regular teacher model.
unet.disable_adapters() accelerator.unwrap_model(unet).disable_adapters()
with torch.no_grad(): with torch.no_grad():
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
cond_teacher_output = unet( cond_teacher_output = unet(
...@@ -1248,7 +1248,7 @@ def main(args): ...@@ -1248,7 +1248,7 @@ def main(args):
x_prev = solver.ddim_step(pred_x0, pred_noise, index).to(unet.dtype) x_prev = solver.ddim_step(pred_x0, pred_noise, index).to(unet.dtype)
# re-enable unet adapters to turn the `unet` into a student unet. # re-enable unet adapters to turn the `unet` into a student unet.
unet.enable_adapters() accelerator.unwrap_model(unet).enable_adapters()
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps) # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
# Note that we do not use a separate target network for LCM-LoRA distillation. # Note that we do not use a separate target network for LCM-LoRA distillation.
...@@ -1332,7 +1332,7 @@ def main(args): ...@@ -1332,7 +1332,7 @@ def main(args):
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
if accelerator.is_main_process: if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet) unet = accelerator.unwrap_model(unet)
unet_lora_state_dict = get_peft_model_state_dict(unet) unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
StableDiffusionXLPipeline.save_lora_weights(args.output_dir, unet_lora_layers=unet_lora_state_dict) StableDiffusionXLPipeline.save_lora_weights(args.output_dir, unet_lora_layers=unet_lora_state_dict)
if args.push_to_hub: if args.push_to_hub:
......
...@@ -54,7 +54,7 @@ from diffusers import ( ...@@ -54,7 +54,7 @@ from diffusers import (
) )
from diffusers.loaders import LoraLoaderMixin from diffusers.loaders import LoraLoaderMixin
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
...@@ -853,9 +853,11 @@ def main(args): ...@@ -853,9 +853,11 @@ def main(args):
for model in models: for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))): if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = get_peft_model_state_dict(model) unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))): elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
text_encoder_lora_layers_to_save = get_peft_model_state_dict(model) text_encoder_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
else: else:
raise ValueError(f"unexpected save model: {model.__class__}") raise ValueError(f"unexpected save model: {model.__class__}")
...@@ -1285,11 +1287,11 @@ def main(args): ...@@ -1285,11 +1287,11 @@ def main(args):
unet = accelerator.unwrap_model(unet) unet = accelerator.unwrap_model(unet)
unet = unet.to(torch.float32) unet = unet.to(torch.float32)
unet_lora_state_dict = get_peft_model_state_dict(unet) unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
if args.train_text_encoder: if args.train_text_encoder:
text_encoder = accelerator.unwrap_model(text_encoder) text_encoder = accelerator.unwrap_model(text_encoder)
text_encoder_state_dict = get_peft_model_state_dict(text_encoder) text_encoder_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder))
else: else:
text_encoder_state_dict = None text_encoder_state_dict = None
......
...@@ -44,7 +44,7 @@ import diffusers ...@@ -44,7 +44,7 @@ import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
...@@ -809,7 +809,9 @@ def main(): ...@@ -809,7 +809,9 @@ def main():
accelerator.save_state(save_path) accelerator.save_state(save_path)
unwrapped_unet = accelerator.unwrap_model(unet) unwrapped_unet = accelerator.unwrap_model(unet)
unet_lora_state_dict = get_peft_model_state_dict(unwrapped_unet) unet_lora_state_dict = convert_state_dict_to_diffusers(
get_peft_model_state_dict(unwrapped_unet)
)
StableDiffusionPipeline.save_lora_weights( StableDiffusionPipeline.save_lora_weights(
save_directory=save_path, save_directory=save_path,
...@@ -876,7 +878,7 @@ def main(): ...@@ -876,7 +878,7 @@ def main():
unet = unet.to(torch.float32) unet = unet.to(torch.float32)
unwrapped_unet = accelerator.unwrap_model(unet) unwrapped_unet = accelerator.unwrap_model(unet)
unet_lora_state_dict = get_peft_model_state_dict(unwrapped_unet) unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))
StableDiffusionPipeline.save_lora_weights( StableDiffusionPipeline.save_lora_weights(
save_directory=args.output_dir, save_directory=args.output_dir,
unet_lora_layers=unet_lora_state_dict, unet_lora_layers=unet_lora_state_dict,
......
...@@ -52,7 +52,7 @@ from diffusers import ( ...@@ -52,7 +52,7 @@ from diffusers import (
from diffusers.loaders import LoraLoaderMixin from diffusers.loaders import LoraLoaderMixin
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
...@@ -651,11 +651,15 @@ def main(args): ...@@ -651,11 +651,15 @@ def main(args):
for model in models: for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))): if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = get_peft_model_state_dict(model) unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model) text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
else: else:
raise ValueError(f"unexpected save model: {model.__class__}") raise ValueError(f"unexpected save model: {model.__class__}")
...@@ -1160,14 +1164,14 @@ def main(args): ...@@ -1160,14 +1164,14 @@ def main(args):
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
if accelerator.is_main_process: if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet) unet = accelerator.unwrap_model(unet)
unet_lora_state_dict = get_peft_model_state_dict(unet) unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
if args.train_text_encoder: if args.train_text_encoder:
text_encoder_one = accelerator.unwrap_model(text_encoder_one) text_encoder_one = accelerator.unwrap_model(text_encoder_one)
text_encoder_two = accelerator.unwrap_model(text_encoder_two) text_encoder_two = accelerator.unwrap_model(text_encoder_two)
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one) text_encoder_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_one))
text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two) text_encoder_2_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_two))
else: else:
text_encoder_lora_layers = None text_encoder_lora_layers = None
text_encoder_2_lora_layers = None text_encoder_2_lora_layers = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment