Unverified Commit 01782c22 authored by Kashif Rasul, committed by GitHub

[Wuerstchen] Adapt lora training example scripts to use PEFT (#5959)

* Adapt lora example scripts to use PEFT

* add to_out.0
parent d63a498c
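
At a glance, the commit drops the manual `LoRAAttnProcessor` / `AttnProcsLayers` wiring and attaches a PEFT adapter to the Wuerstchen prior instead. Below is a minimal sketch of the new pattern, not the full script: it assumes `peft>=0.6.0`, a `diffusers` version with `add_adapter` support, and the `warp-ai/wuerstchen-prior` checkpoint; the rank value is arbitrary (the script exposes it as `--rank`).

```python
import torch
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from diffusers.pipelines.wuerstchen import WuerstchenPrior

# Load the Stage C prior and freeze its base weights.
# The checkpoint id and subfolder are assumptions for this sketch.
prior = WuerstchenPrior.from_pretrained("warp-ai/wuerstchen-prior", subfolder="prior")
prior.requires_grad_(False)

# Attach LoRA matrices to the attention projections via PEFT,
# using the same target_modules list as the training script.
prior_lora_config = LoraConfig(
    r=4,  # LoRA rank; arbitrary here
    target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
)
prior.add_adapter(prior_lora_config)

# Only the injected LoRA parameters require gradients, so the optimizer
# sees a small parameter set rather than the whole prior.
params_to_optimize = list(filter(lambda p: p.requires_grad, prior.parameters()))
optimizer = torch.optim.AdamW(params_to_optimize, lr=1e-4)

# This LoRA-only state dict is what the script serializes via
# WuerstchenPriorPipeline.save_lora_weights(...).
prior_lora_state_dict = get_peft_model_state_dict(prior)
print(f"trainable LoRA tensors: {len(prior_lora_state_dict)}")
```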
@@ -5,3 +5,4 @@ wandb
 huggingface-cli
 bitsandbytes
 deepspeed
+peft>=0.6.0
@@ -31,14 +31,14 @@ from accelerate.utils import ProjectConfiguration, set_seed
 from datasets import load_dataset
 from huggingface_hub import create_repo, hf_hub_download, upload_folder
 from modeling_efficient_net_encoder import EfficientNetEncoder
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict
 from torchvision import transforms
 from tqdm import tqdm
 from transformers import CLIPTextModel, PreTrainedTokenizerFast
 from transformers.utils import ContextManagers

 from diffusers import AutoPipelineForText2Image, DDPMWuerstchenScheduler, WuerstchenPriorPipeline
-from diffusers.loaders import AttnProcsLayers
-from diffusers.models.attention_processor import LoRAAttnProcessor
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPrior
 from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
@@ -139,17 +139,17 @@ More information on all the CLI arguments and the environment are available on y
         f.write(yaml + model_card)


-def log_validation(text_encoder, tokenizer, attn_processors, args, accelerator, weight_dtype, epoch):
+def log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, epoch):
     logger.info("Running validation... ")

     pipeline = AutoPipelineForText2Image.from_pretrained(
         args.pretrained_decoder_model_name_or_path,
+        prior=accelerator.unwrap_model(prior),
         prior_text_encoder=accelerator.unwrap_model(text_encoder),
         prior_tokenizer=tokenizer,
         torch_dtype=weight_dtype,
     )
     pipeline = pipeline.to(accelerator.device)
-    pipeline.prior_prior.set_attn_processor(attn_processors)
     pipeline.set_progress_bar_config(disable=True)

     if args.seed is None:
@@ -159,7 +159,7 @@ def log_validation(text_encoder, tokenizer, attn_processors, args, accelerator,
     images = []
     for i in range(len(args.validation_prompts)):
-        with torch.autocast("cuda"):
+        with torch.cuda.amp.autocast():
             image = pipeline(
                 args.validation_prompts[i],
                 prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
@@ -167,7 +167,6 @@ def log_validation(text_encoder, tokenizer, attn_processors, args, accelerator,
                 height=args.resolution,
                 width=args.resolution,
             ).images[0]
-
         images.append(image)

     for tracker in accelerator.trackers:
@@ -527,11 +526,50 @@ def main():
     prior.to(accelerator.device, dtype=weight_dtype)

     # lora attn processor
-    lora_attn_procs = {}
-    for name in prior.attn_processors.keys():
-        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=prior.config["c"], rank=args.rank)
-    prior.set_attn_processor(lora_attn_procs)
-    lora_layers = AttnProcsLayers(prior.attn_processors)
+    prior_lora_config = LoraConfig(
+        r=args.rank, target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"]
+    )
+    prior.add_adapter(prior_lora_config)
+
+    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+    def save_model_hook(models, weights, output_dir):
+        if accelerator.is_main_process:
+            prior_lora_layers_to_save = None
+
+            for model in models:
+                if isinstance(model, type(accelerator.unwrap_model(prior))):
+                    prior_lora_layers_to_save = get_peft_model_state_dict(model)
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")
+
+                # make sure to pop weight so that corresponding model is not saved again
+                weights.pop()
+
+            WuerstchenPriorPipeline.save_lora_weights(
+                output_dir,
+                unet_lora_layers=prior_lora_layers_to_save,
+            )
+
+    def load_model_hook(models, input_dir):
+        prior_ = None
+
+        while len(models) > 0:
+            model = models.pop()
+            if isinstance(model, type(accelerator.unwrap_model(prior))):
+                prior_ = model
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")
+
+        lora_state_dict, network_alphas = WuerstchenPriorPipeline.lora_state_dict(input_dir)
+        WuerstchenPriorPipeline.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=prior_)
+        WuerstchenPriorPipeline.load_lora_into_text_encoder(
+            lora_state_dict,
+            network_alphas=network_alphas,
+        )
+
+    accelerator.register_save_state_pre_hook(save_model_hook)
+    accelerator.register_load_state_pre_hook(load_model_hook)
+
     if args.allow_tf32:
         torch.backends.cuda.matmul.allow_tf32 = True
@@ -547,8 +585,9 @@ def main():
         optimizer_cls = bnb.optim.AdamW8bit
     else:
         optimizer_cls = torch.optim.AdamW
+    params_to_optimize = list(filter(lambda p: p.requires_grad, prior.parameters()))
     optimizer = optimizer_cls(
-        lora_layers.parameters(),
+        params_to_optimize,
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
......@@ -674,8 +713,8 @@ def main():
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
)
lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
lora_layers, optimizer, train_dataloader, lr_scheduler
prior, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
prior, optimizer, train_dataloader, lr_scheduler
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
@@ -782,7 +821,7 @@ def main():
                 # Backpropagate
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    accelerator.clip_grad_norm_(lora_layers.parameters(), args.max_grad_norm)
+                    accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()
@@ -828,17 +867,19 @@ def main():
         if accelerator.is_main_process:
             if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
-                log_validation(
-                    text_encoder, tokenizer, prior.attn_processors, args, accelerator, weight_dtype, global_step
-                )
+                log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, global_step)

     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+        prior = accelerator.unwrap_model(prior)
+        prior = prior.to(torch.float32)
+
+        prior_lora_state_dict = get_peft_model_state_dict(prior)
+
         WuerstchenPriorPipeline.save_lora_weights(
-            os.path.join(args.output_dir, "prior_lora"),
-            unet_lora_layers=lora_layers,
+            save_directory=args.output_dir,
+            unet_lora_layers=prior_lora_state_dict,
         )

         # Run a final round of inference.
@@ -849,11 +890,12 @@ def main():
                 args.pretrained_decoder_model_name_or_path,
                 prior_text_encoder=accelerator.unwrap_model(text_encoder),
                 prior_tokenizer=tokenizer,
+                torch_dtype=weight_dtype,
             )
-            pipeline = pipeline.to(accelerator.device, torch_dtype=weight_dtype)
-            # load lora weights
-            pipeline.prior_pipe.load_lora_weights(os.path.join(args.output_dir, "prior_lora"))
+            pipeline = pipeline.to(accelerator.device)
+            # load lora weights
+            pipeline.prior_pipe.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
             pipeline.set_progress_bar_config(disable=True)

             if args.seed is None:
@@ -862,7 +904,7 @@ def main():
                 generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

             for i in range(len(args.validation_prompts)):
-                with torch.autocast("cuda"):
+                with torch.cuda.amp.autocast():
                     image = pipeline(
                         args.validation_prompts[i],
                         prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
......
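
The tail of the diff also changes where the trained LoRA lands: it is now written to the top level of `--output_dir` as `pytorch_lora_weights.safetensors` and loaded back into the combined pipeline's prior via `pipeline.prior_pipe.load_lora_weights(...)`. A hedged usage sketch along those lines follows; the checkpoint id and output directory are placeholders, not values taken from the script.

```python
import torch
from diffusers import AutoPipelineForText2Image
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS

# Load the combined Wuerstchen pipeline (prior + decoder); checkpoint id is a placeholder.
pipeline = AutoPipelineForText2Image.from_pretrained(
    "warp-ai/wuerstchen", torch_dtype=torch.float16
).to("cuda")

# Load the trained prior LoRA the same way the script's final-inference block does:
# the weights sit at the top level of the training output directory.
pipeline.prior_pipe.load_lora_weights(
    "wuerstchen-prior-lora",  # placeholder for --output_dir
    weight_name="pytorch_lora_weights.safetensors",
)

image = pipeline(
    "a red sofa in a sunlit living room",
    prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
    height=512,
    width=512,
).images[0]
image.save("sample.png")
```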