Unverified commit 01782c22, authored by Kashif Rasul, committed by GitHub

[Wuerstchen] Adapt lora training example scripts to use PEFT (#5959)

* Adapt lora example scripts to use PEFT

* add to_out.0
parent d63a498c
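
In outline, the commit drops diffusers' `LoRAAttnProcessor`/`AttnProcsLayers` plumbing in favor of PEFT's adapter API. A minimal sketch of the new pattern, assuming peft>=0.6.0 and a diffusers version with PEFT adapter support (the hub ID and rank below are illustrative assumptions, not taken from the diff):

```python
# Sketch of the PEFT LoRA pattern this commit adopts; not the training script itself.
from diffusers.pipelines.wuerstchen import WuerstchenPrior
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict

prior = WuerstchenPrior.from_pretrained("warp-ai/wuerstchen-prior", subfolder="prior")

# Inject trainable low-rank adapters into the attention projections;
# every base weight stays frozen.
prior.add_adapter(LoraConfig(r=4, target_modules=["to_k", "to_q", "to_v", "to_out.0"]))

# Only the adapter weights end up in the checkpoint.
lora_state_dict = get_peft_model_state_dict(prior)
```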
@@ -5,3 +5,4 @@ wandb
 huggingface-cli
 bitsandbytes
 deepspeed
+peft>=0.6.0
@@ -31,14 +31,14 @@ from accelerate.utils import ProjectConfiguration, set_seed
 from datasets import load_dataset
 from huggingface_hub import create_repo, hf_hub_download, upload_folder
 from modeling_efficient_net_encoder import EfficientNetEncoder
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict
 from torchvision import transforms
 from tqdm import tqdm
 from transformers import CLIPTextModel, PreTrainedTokenizerFast
 from transformers.utils import ContextManagers

 from diffusers import AutoPipelineForText2Image, DDPMWuerstchenScheduler, WuerstchenPriorPipeline
-from diffusers.loaders import AttnProcsLayers
-from diffusers.models.attention_processor import LoRAAttnProcessor
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPrior
 from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
@@ -139,17 +139,17 @@ More information on all the CLI arguments and the environment are available on y
     f.write(yaml + model_card)


-def log_validation(text_encoder, tokenizer, attn_processors, args, accelerator, weight_dtype, epoch):
+def log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, epoch):
     logger.info("Running validation... ")

     pipeline = AutoPipelineForText2Image.from_pretrained(
         args.pretrained_decoder_model_name_or_path,
+        prior=accelerator.unwrap_model(prior),
         prior_text_encoder=accelerator.unwrap_model(text_encoder),
         prior_tokenizer=tokenizer,
         torch_dtype=weight_dtype,
     )
     pipeline = pipeline.to(accelerator.device)
-    pipeline.prior_prior.set_attn_processor(attn_processors)
     pipeline.set_progress_bar_config(disable=True)

     if args.seed is None:
@@ -159,7 +159,7 @@ def log_validation(text_encoder, tokenizer, attn_processors, args, accelerator,

     images = []
     for i in range(len(args.validation_prompts)):
-        with torch.autocast("cuda"):
+        with torch.cuda.amp.autocast():
             image = pipeline(
                 args.validation_prompts[i],
                 prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
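
For context, the two autocast spellings touched by this hunk both enable mixed-precision execution of CUDA ops; the script merely switches from the generic API to the CUDA-specific one. A minimal illustration (a sketch, not part of the commit):

```python
import torch

# Equivalent on CUDA: generic spelling (old) vs. CUDA-specific spelling (new).
with torch.autocast(device_type="cuda"):
    pass

with torch.cuda.amp.autocast():
    pass
```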
@@ -167,7 +167,6 @@ def log_validation(text_encoder, tokenizer, attn_processors, args, accelerator,
                 height=args.resolution,
                 width=args.resolution,
             ).images[0]
-
         images.append(image)

     for tracker in accelerator.trackers:
@@ -527,11 +526,50 @@ def main():
     prior.to(accelerator.device, dtype=weight_dtype)

     # lora attn processor
-    lora_attn_procs = {}
-    for name in prior.attn_processors.keys():
-        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=prior.config["c"], rank=args.rank)
-    prior.set_attn_processor(lora_attn_procs)
-    lora_layers = AttnProcsLayers(prior.attn_processors)
+    prior_lora_config = LoraConfig(
+        r=args.rank, target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"]
+    )
+    prior.add_adapter(prior_lora_config)
+
+    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+    def save_model_hook(models, weights, output_dir):
+        if accelerator.is_main_process:
+            prior_lora_layers_to_save = None
+
+            for model in models:
+                if isinstance(model, type(accelerator.unwrap_model(prior))):
+                    prior_lora_layers_to_save = get_peft_model_state_dict(model)
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")
+
+                # make sure to pop weight so that corresponding model is not saved again
+                weights.pop()
+
+            WuerstchenPriorPipeline.save_lora_weights(
+                output_dir,
+                unet_lora_layers=prior_lora_layers_to_save,
+            )
+
+    def load_model_hook(models, input_dir):
+        prior_ = None
+
+        while len(models) > 0:
+            model = models.pop()
+
+            if isinstance(model, type(accelerator.unwrap_model(prior))):
+                prior_ = model
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")
+
+        lora_state_dict, network_alphas = WuerstchenPriorPipeline.lora_state_dict(input_dir)
+        WuerstchenPriorPipeline.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=prior_)
+        WuerstchenPriorPipeline.load_lora_into_text_encoder(
+            lora_state_dict,
+            network_alphas=network_alphas,
+        )
+
+    accelerator.register_save_state_pre_hook(save_model_hook)
+    accelerator.register_load_state_pre_hook(load_model_hook)
+
     if args.allow_tf32:
         torch.backends.cuda.matmul.allow_tf32 = True
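
Accelerate invokes these hooks from `save_state`/`load_state`, so checkpoints now contain only the LoRA weights rather than the full prior. A hedged usage sketch, assuming `accelerator` is configured as in the script (the checkpoint path is an assumption):

```python
# save_state routes through save_model_hook; load_state through load_model_hook.
accelerator.save_state("wuerstchen-prior-lora/checkpoint-500")
accelerator.load_state("wuerstchen-prior-lora/checkpoint-500")
```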
@@ -547,8 +585,9 @@ def main():
         optimizer_cls = bnb.optim.AdamW8bit
     else:
         optimizer_cls = torch.optim.AdamW
+    params_to_optimize = list(filter(lambda p: p.requires_grad, prior.parameters()))
     optimizer = optimizer_cls(
-        lora_layers.parameters(),
+        params_to_optimize,
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
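
Because `add_adapter` freezes every base weight, the `requires_grad` filter yields just the injected LoRA matrices. A quick sanity check along these lines (a sketch, assuming `prior` already has the adapter attached):

```python
# Count trainable vs. total parameters; only the LoRA matrices require grad.
params_to_optimize = [p for p in prior.parameters() if p.requires_grad]
trainable = sum(p.numel() for p in params_to_optimize)
total = sum(p.numel() for p in prior.parameters())
print(f"trainable params: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")
```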
@@ -674,8 +713,8 @@ def main():
         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
     )

-    lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        lora_layers, optimizer, train_dataloader, lr_scheduler
+    prior, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        prior, optimizer, train_dataloader, lr_scheduler
     )

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
@@ -782,7 +821,7 @@ def main():
                 # Backpropagate
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    accelerator.clip_grad_norm_(lora_layers.parameters(), args.max_grad_norm)
+                    accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()
@@ -828,17 +867,19 @@ def main():
         if accelerator.is_main_process:
             if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
-                log_validation(
-                    text_encoder, tokenizer, prior.attn_processors, args, accelerator, weight_dtype, global_step
-                )
+                log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, global_step)


     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+        prior = accelerator.unwrap_model(prior)
         prior = prior.to(torch.float32)
+
+        prior_lora_state_dict = get_peft_model_state_dict(prior)
+
         WuerstchenPriorPipeline.save_lora_weights(
-            os.path.join(args.output_dir, "prior_lora"),
-            unet_lora_layers=lora_layers,
+            save_directory=args.output_dir,
+            unet_lora_layers=prior_lora_state_dict,
         )

     # Run a final round of inference.
@@ -849,11 +890,12 @@ def main():
             args.pretrained_decoder_model_name_or_path,
             prior_text_encoder=accelerator.unwrap_model(text_encoder),
             prior_tokenizer=tokenizer,
+            torch_dtype=weight_dtype,
         )
-        pipeline = pipeline.to(accelerator.device, torch_dtype=weight_dtype)
+        pipeline = pipeline.to(accelerator.device)

-        # load lora weights
-        pipeline.prior_pipe.load_lora_weights(os.path.join(args.output_dir, "prior_lora"))
+        # load lora weights
+        pipeline.prior_pipe.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
         pipeline.set_progress_bar_config(disable=True)

         if args.seed is None:
@@ -862,7 +904,7 @@ def main():
             generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

         for i in range(len(args.validation_prompts)):
-            with torch.autocast("cuda"):
+            with torch.cuda.amp.autocast():
                 image = pipeline(
                     args.validation_prompts[i],
                     prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
...
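
After training, the saved adapter can be reloaded for inference along the same lines as the script's final validation pass. A sketch (the output directory is an assumption; `warp-ai/wuerstchen` is the combined-pipeline hub ID):

```python
import torch
from diffusers import AutoPipelineForText2Image

# Combined Wuerstchen pipeline; the prior LoRA is loaded into its prior_pipe.
pipeline = AutoPipelineForText2Image.from_pretrained(
    "warp-ai/wuerstchen", torch_dtype=torch.float16
).to("cuda")
# The script writes the prior LoRA to --output_dir as pytorch_lora_weights.safetensors.
pipeline.prior_pipe.load_lora_weights(
    "wuerstchen-prior-lora", weight_name="pytorch_lora_weights.safetensors"
)
image = pipeline("a red sports car", height=512, width=512).images[0]
image.save("sample.png")
```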