Unverified Commit 35f45ecd authored by Linoy Tsaban, committed by GitHub

[Advanced dreambooth lora] adjustments to align with canonical script (#8406)



* minor changes

* minor changes

* minor changes

* minor changes

* minor changes

* minor changes

* minor changes

* fix

* fix

* aligning with blora script

* aligning with blora script

* aligning with blora script

* aligning with blora script

* aligning with blora script

* remove prints

* style

* default val

* license

* move save_model_card to outside push_to_hub

* Update train_dreambooth_lora_sdxl_advanced.py

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent d5dd8df3
@@ -31,8 +31,6 @@ from typing import List, Optional
 import numpy as np
 import torch
 import torch.nn.functional as F
-
-# imports of the TokenEmbeddingsHandler class
 import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
@@ -77,6 +75,9 @@ from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
 
+
+if is_wandb_available():
+    import wandb
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.30.0.dev0")
@@ -101,12 +102,12 @@ def save_model_card(
     repo_id: str,
     use_dora: bool,
     images=None,
-    base_model=str,
+    base_model: str = None,
     train_text_encoder=False,
     train_text_encoder_ti=False,
     token_abstraction_dict=None,
-    instance_prompt=str,
-    validation_prompt=str,
+    instance_prompt: str = None,
+    validation_prompt: str = None,
     repo_folder=None,
     vae_path=None,
 ):
@@ -135,6 +136,14 @@ def save_model_card(
     diffusers_imports_pivotal = ""
     diffusers_example_pivotal = ""
     webui_example_pivotal = ""
+    license = ""
+    if "playground" in base_model:
+        license = """\n
+## License
+
+Please adhere to the licensing terms as described [here](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic/blob/main/LICENSE.md).
+"""
     if train_text_encoder_ti:
         trigger_str = (
             "To trigger image generation of trained concept(or concepts) replace each concept identifier "
@@ -223,11 +232,75 @@ Pivotal tuning was enabled: {train_text_encoder_ti}.
 
 Special VAE used for training: {vae_path}.
 
+{license}
 """
     with open(os.path.join(repo_folder, "README.md"), "w") as f:
         f.write(yaml + model_card)
 
 
+def log_validation(
+    pipeline,
+    args,
+    accelerator,
+    pipeline_args,
+    epoch,
+    is_final_validation=False,
+):
+    logger.info(
+        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+        f" {args.validation_prompt}."
+    )
+
+    # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+    scheduler_args = {}
+
+    if not args.do_edm_style_training:
+        if "variance_type" in pipeline.scheduler.config:
+            variance_type = pipeline.scheduler.config.variance_type
+
+            if variance_type in ["learned", "learned_range"]:
+                variance_type = "fixed_small"
+
+            scheduler_args["variance_type"] = variance_type
+
+        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=True)
+
+    # run inference
+    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+    # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
+    # way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
+    if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)
+
+    with autocast_ctx:
+        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
+
+    for tracker in accelerator.trackers:
+        phase_name = "test" if is_final_validation else "validation"
+        if tracker.name == "tensorboard":
+            np_images = np.stack([np.asarray(img) for img in images])
+            tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+        if tracker.name == "wandb":
+            tracker.log(
+                {
+                    phase_name: [
+                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+                    ]
+                }
+            )
+
+    del pipeline
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+    return images
+
+
 def import_model_class_from_model_name_or_path(
     pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
 ):
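Note: the new log_validation helper above is shared by the per-epoch validation loop and the final test pass. A minimal usage sketch, where pipeline, args, accelerator and epoch stand for the objects already built inside main():

    pipeline_args = {"prompt": args.validation_prompt}
    # during training, images are logged under the "validation" phase name
    images = log_validation(pipeline, args, accelerator, pipeline_args, epoch)
    # after training, is_final_validation=True switches the tracker phase name to "test"
    images = log_validation(pipeline, args, accelerator, pipeline_args, epoch, is_final_validation=True)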
@@ -390,6 +463,7 @@ def parse_args(input_args=None):
     )
     parser.add_argument(
         "--do_edm_style_training",
+        default=False,
         action="store_true",
         help="Flag to conduct training using the EDM formulation as introduced in https://arxiv.org/abs/2206.00364.",
     )
@@ -571,7 +645,7 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--optimizer",
         type=str,
-        default="adamW",
+        default="AdamW",
         help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
     )
@@ -906,11 +980,6 @@ class DreamBoothDataset(Dataset):
         instance_data_root,
         instance_prompt,
         class_prompt,
-        dataset_name,
-        dataset_config_name,
-        cache_dir,
-        image_column,
-        caption_column,
         train_text_encoder_ti,
         class_data_root=None,
         class_num=None,
@@ -929,7 +998,7 @@ class DreamBoothDataset(Dataset):
         self.train_text_encoder_ti = train_text_encoder_ti
         # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
         # we load the training data using load_dataset
-        if dataset_name is not None:
+        if args.dataset_name is not None:
             try:
                 from datasets import load_dataset
             except ImportError:
@@ -942,25 +1011,26 @@ class DreamBoothDataset(Dataset):
             # See more about loading custom images at
             # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
             dataset = load_dataset(
-                dataset_name,
-                dataset_config_name,
-                cache_dir=cache_dir,
+                args.dataset_name,
+                args.dataset_config_name,
+                cache_dir=args.cache_dir,
             )
             # Preprocessing the datasets.
             column_names = dataset["train"].column_names
 
             # 6. Get the column names for input/target.
-            if image_column is None:
+            if args.image_column is None:
                 image_column = column_names[0]
                 logger.info(f"image column defaulting to {image_column}")
             else:
+                image_column = args.image_column
                 if image_column not in column_names:
                     raise ValueError(
-                        f"`--image_column` value '{image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+                        f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
                     )
             instance_images = dataset["train"][image_column]
 
-            if caption_column is None:
+            if args.caption_column is None:
                 logger.info(
                     "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
                     "contains captions/prompts for the images, make sure to specify the "
@@ -968,11 +1038,11 @@ class DreamBoothDataset(Dataset):
                 )
                 self.custom_instance_prompts = None
             else:
-                if caption_column not in column_names:
+                if args.caption_column not in column_names:
                     raise ValueError(
-                        f"`--caption_column` value '{caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+                        f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
                     )
-                custom_instance_prompts = dataset["train"][caption_column]
+                custom_instance_prompts = dataset["train"][args.caption_column]
                 # create final list of captions according to --repeats
                 self.custom_instance_prompts = []
                 for caption in custom_instance_prompts:
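Note: with the constructor slimmed down, DreamBoothDataset now reads the dataset-related options straight from the parsed args. A rough sketch of the column resolution above, assuming a datasets.DatasetDict with a "train" split whose columns are, say, ["image", "text"]:

    column_names = dataset["train"].column_names
    if args.image_column is None:
        image_column = column_names[0]  # default to the first column
    else:
        image_column = args.image_column  # must exist in column_names, otherwise a ValueError is raised
    instance_images = dataset["train"][image_column]

    if args.caption_column is None:
        custom_instance_prompts = None  # every image falls back to --instance_prompt
    else:
        custom_instance_prompts = dataset["train"][args.caption_column]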
@@ -1178,13 +1248,12 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
             text_input_ids = text_input_ids_list[i]
 
         prompt_embeds = text_encoder(
-            text_input_ids.to(text_encoder.device),
-            output_hidden_states=True,
+            text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False
         )
 
         # We are only ALWAYS interested in the pooled output of the final text encoder
         pooled_prompt_embeds = prompt_embeds[0]
-        prompt_embeds = prompt_embeds.hidden_states[-2]
+        prompt_embeds = prompt_embeds[-1][-2]
         bs_embed, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
         prompt_embeds_list.append(prompt_embeds)
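Note: with return_dict=False the text encoder forward returns a plain tuple instead of an output object, which is why the indexing above changes. A sketch, assuming a CLIPTextModelWithProjection encoder as used for SDXL:

    outputs = text_encoder(text_input_ids, output_hidden_states=True, return_dict=False)
    pooled_prompt_embeds = outputs[0]  # projected (pooled) embedding, SDXL's "text_embeds"
    hidden_states = outputs[-1]        # per-layer hidden states, present because output_hidden_states=True
    prompt_embeds = hidden_states[-2]  # penultimate layer, equivalent to .hidden_states[-2] in the old return_dict path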
@@ -1200,9 +1269,16 @@ def main(args):
             "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
             " Please use `huggingface-cli login` to authenticate with the Hub."
         )
 
     if args.do_edm_style_training and args.snr_gamma is not None:
         raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.")
 
+    if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+        # due to pytorch#99272, MPS does not yet support bfloat16.
+        raise ValueError(
+            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+        )
+
     logging_dir = Path(args.output_dir, args.logging_dir)
 
     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
@@ -1215,10 +1291,13 @@
         kwargs_handlers=[kwargs],
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
-        import wandb
 
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
@@ -1246,7 +1325,8 @@
         cur_class_images = len(list(class_images_dir.iterdir()))
 
         if cur_class_images < args.num_class_images:
-            torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+            has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
+            torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
             if args.prior_generation_precision == "fp32":
                 torch_dtype = torch.float32
             elif args.prior_generation_precision == "fp16":
@@ -1404,6 +1484,12 @@
     elif accelerator.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
+    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
+        # due to pytorch#99272, MPS does not yet support bfloat16.
+        raise ValueError(
+            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+        )
+
     # Move unet, vae and text_encoder to device and cast to weight_dtype
     unet.to(accelerator.device, dtype=weight_dtype)
@@ -1508,15 +1594,13 @@
                 if isinstance(model, type(unwrap_model(unet))):
                     unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
                 elif isinstance(model, type(unwrap_model(text_encoder_one))):
-                    if args.train_text_encoder:
-                        text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
-                            get_peft_model_state_dict(model)
-                        )
+                    text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
+                        get_peft_model_state_dict(model)
+                    )
                 elif isinstance(model, type(unwrap_model(text_encoder_two))):
-                    if args.train_text_encoder:
-                        text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
-                            get_peft_model_state_dict(model)
-                        )
+                    text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
+                        get_peft_model_state_dict(model)
+                    )
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1564,6 +1648,7 @@
         )
 
         if args.train_text_encoder:
+            # Do we need to call `scale_lora_layers()` here?
             _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
 
             _set_state_dict_into_text_encoder(
@@ -1578,14 +1663,14 @@
             if args.train_text_encoder:
                 models.extend([text_encoder_one_, text_encoder_two_])
 
             # only upcast trainable parameters (LoRA) into fp32
             cast_training_params(models)
 
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
 
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
-    if args.allow_tf32:
+    if args.allow_tf32 and torch.cuda.is_available():
         torch.backends.cuda.matmul.allow_tf32 = True
 
     if args.scale_lr:
@@ -1711,12 +1796,7 @@
         instance_data_root=args.instance_data_dir,
         instance_prompt=args.instance_prompt,
         class_prompt=args.class_prompt,
-        dataset_name=args.dataset_name,
-        dataset_config_name=args.dataset_config_name,
-        cache_dir=args.cache_dir,
-        image_column=args.image_column,
         train_text_encoder_ti=args.train_text_encoder_ti,
-        caption_column=args.caption_column,
         class_data_root=args.class_data_dir if args.with_prior_preservation else None,
         token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None,
         class_num=args.num_class_images,
@@ -1740,8 +1820,6 @@
     def compute_time_ids(crops_coords_top_left, original_size=None):
         # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
-        if original_size is None:
-            original_size = (args.resolution, args.resolution)
         target_size = (args.resolution, args.resolution)
         add_time_ids = list(original_size + crops_coords_top_left + target_size)
         add_time_ids = torch.tensor([add_time_ids])
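Note: with the fallback removed, callers of compute_time_ids are now expected to always pass the per-image original_size recorded by the dataset. A sketch of the SDXL micro-conditioning vector the helper builds (the sizes below are hypothetical):

    original_size = (1024, 768)  # per-image size taken from the batch, no longer defaulted to args.resolution
    crops_coords_top_left = (0, 0)
    target_size = (args.resolution, args.resolution)
    add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)])
    # -> tensor([[1024, 768, 0, 0, args.resolution, args.resolution]]), later moved to the UNet's device/dtype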
@@ -1778,7 +1856,8 @@
     if freeze_text_encoder and not train_dataset.custom_instance_prompts:
         del tokenizers, text_encoders
         gc.collect()
-        torch.cuda.empty_cache()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
 
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
@@ -1946,8 +2025,8 @@
             text_encoder_two.train()
 
             # set top parameter requires_grad = True for gradient checkpointing works
             if args.train_text_encoder:
-                text_encoder_one.text_model.embeddings.requires_grad_(True)
-                text_encoder_two.text_model.embeddings.requires_grad_(True)
+                accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+                accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
 
         for step, batch in enumerate(train_dataloader):
             if pivoted:
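Note: unwrapping is needed because accelerator.prepare() may wrap the text encoders (for example in DistributedDataParallel), in which case .text_model only exists on the inner module. A sketch of the assumption behind the change above:

    wrapped_encoder = accelerator.prepare(text_encoder_one)    # may return a DDP-wrapped module
    inner_encoder = accelerator.unwrap_model(wrapped_encoder)  # recovers the underlying CLIP text model
    inner_encoder.text_model.embeddings.requires_grad_(True)   # gives gradient checkpointing a grad-requiring input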
@@ -2040,7 +2119,6 @@
                     if freeze_text_encoder:
                         unet_added_conditions = {
                             "time_ids": add_time_ids,
-                            # "time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1),
                             "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
                         }
                         prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
@@ -2220,10 +2298,6 @@
         if accelerator.is_main_process:
             if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
-                logger.info(
-                    f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
-                    f" {args.validation_prompt}."
-                )
                 # create pipeline
                 if freeze_text_encoder:
                     text_encoder_one = text_encoder_cls_one.from_pretrained(
@@ -2250,70 +2324,29 @@
                     variant=args.variant,
                     torch_dtype=weight_dtype,
                 )
 
-                # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
-                scheduler_args = {}
-
-                if not args.do_edm_style_training:
-                    if "variance_type" in pipeline.scheduler.config:
-                        variance_type = pipeline.scheduler.config.variance_type
-
-                        if variance_type in ["learned", "learned_range"]:
-                            variance_type = "fixed_small"
-
-                        scheduler_args["variance_type"] = variance_type
-
-                    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
-                        pipeline.scheduler.config, **scheduler_args
-                    )
-
-                pipeline = pipeline.to(accelerator.device)
-                pipeline.set_progress_bar_config(disable=True)
-
-                # run inference
-                generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
                 pipeline_args = {"prompt": args.validation_prompt}
-
-                if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
-                    autocast_ctx = nullcontext()
-                else:
-                    autocast_ctx = torch.autocast(accelerator.device.type)
-
-                with autocast_ctx:
-                    images = [
-                        pipeline(**pipeline_args, generator=generator).images[0]
-                        for _ in range(args.num_validation_images)
-                    ]
-
-                for tracker in accelerator.trackers:
-                    if tracker.name == "tensorboard":
-                        np_images = np.stack([np.asarray(img) for img in images])
-                        tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
-                    if tracker.name == "wandb":
-                        tracker.log(
-                            {
-                                "validation": [
-                                    wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
-                                    for i, image in enumerate(images)
-                                ]
-                            }
-                        )
-
-                del pipeline
-                torch.cuda.empty_cache()
+                images = log_validation(
+                    pipeline,
+                    args,
+                    accelerator,
+                    pipeline_args,
+                    epoch,
+                )
 
         # Save the lora layers
         accelerator.wait_for_everyone()
         if accelerator.is_main_process:
-            unet = accelerator.unwrap_model(unet)
+            unet = unwrap_model(unet)
             unet = unet.to(torch.float32)
             unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
 
             if args.train_text_encoder:
-                text_encoder_one = accelerator.unwrap_model(text_encoder_one)
+                text_encoder_one = unwrap_model(text_encoder_one)
                 text_encoder_lora_layers = convert_state_dict_to_diffusers(
                     get_peft_model_state_dict(text_encoder_one.to(torch.float32))
                 )
-                text_encoder_two = accelerator.unwrap_model(text_encoder_two)
+                text_encoder_two = unwrap_model(text_encoder_two)
                 text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
                     get_peft_model_state_dict(text_encoder_two.to(torch.float32))
                 )
@@ -2332,85 +2365,39 @@
             embeddings_path = f"{args.output_dir}/{args.output_dir}_emb.safetensors"
             embedding_handler.save_embeddings(embeddings_path)
 
+        # Final inference
+        # Load previous pipeline
+        vae = AutoencoderKL.from_pretrained(
+            vae_path,
+            subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+            revision=args.revision,
+            variant=args.variant,
+            torch_dtype=weight_dtype,
+        )
+        pipeline = StableDiffusionXLPipeline.from_pretrained(
+            args.pretrained_model_name_or_path,
+            vae=vae,
+            revision=args.revision,
+            variant=args.variant,
+            torch_dtype=weight_dtype,
+        )
+
+        # load attention processors
+        pipeline.load_lora_weights(args.output_dir)
+
+        # run inference
         images = []
         if args.validation_prompt and args.num_validation_images > 0:
-            # Final inference
-            # Load previous pipeline
-            vae = AutoencoderKL.from_pretrained(
-                vae_path,
-                subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
-                revision=args.revision,
-                variant=args.variant,
-                torch_dtype=weight_dtype,
-            )
-            pipeline = StableDiffusionXLPipeline.from_pretrained(
-                args.pretrained_model_name_or_path,
-                vae=vae,
-                revision=args.revision,
-                variant=args.variant,
-                torch_dtype=weight_dtype,
-            )
-
-            # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
-            scheduler_args = {}
-
-            if not args.do_edm_style_training:
-                if "variance_type" in pipeline.scheduler.config:
-                    variance_type = pipeline.scheduler.config.variance_type
-
-                    if variance_type in ["learned", "learned_range"]:
-                        variance_type = "fixed_small"
-
-                    scheduler_args["variance_type"] = variance_type
-
-                pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
-                    pipeline.scheduler.config, **scheduler_args
-                )
-
-            # load attention processors
-            pipeline.load_lora_weights(args.output_dir)
-
-            # load new tokens
-            if args.train_text_encoder_ti:
-                state_dict = load_file(embeddings_path)
-                all_new_tokens = []
-                for key, value in token_abstraction_dict.items():
-                    all_new_tokens.extend(value)
-                pipeline.load_textual_inversion(
-                    state_dict["clip_l"],
-                    token=all_new_tokens,
-                    text_encoder=pipeline.text_encoder,
-                    tokenizer=pipeline.tokenizer,
-                )
-                pipeline.load_textual_inversion(
-                    state_dict["clip_g"],
-                    token=all_new_tokens,
-                    text_encoder=pipeline.text_encoder_2,
-                    tokenizer=pipeline.tokenizer_2,
-                )
-
-            # run inference
-            pipeline = pipeline.to(accelerator.device)
-            generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
-            images = [
-                pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
-                for _ in range(args.num_validation_images)
-            ]
-
-            for tracker in accelerator.trackers:
-                if tracker.name == "tensorboard":
-                    np_images = np.stack([np.asarray(img) for img in images])
-                    tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
-                if tracker.name == "wandb":
-                    tracker.log(
-                        {
-                            "test": [
-                                wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
-                                for i, image in enumerate(images)
-                            ]
-                        }
-                    )
+            pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25}
+            images = log_validation(
+                pipeline,
+                args,
+                accelerator,
+                pipeline_args,
+                epoch,
+                is_final_validation=True,
+            )
 
         # Convert to WebUI format
         lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
         peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
@@ -2430,6 +2417,7 @@
             repo_folder=args.output_dir,
             vae_path=args.pretrained_vae_model_name_or_path,
         )
+
         if args.push_to_hub:
             upload_folder(
                 repo_id=repo_id,
......