import os import numpy as np import cv2 import io from datasets import load_dataset import argparse from pathlib import Path import json import time import torch import torch.nn as nn import torch.nn.functional as F from PIL import Image from accelerate import Accelerator from accelerate.utils import ProjectConfiguration from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel from transformers import CLIPTextModel, CLIPTokenizer import functools from ..linfusion import LinFusion def get_submodule(model, module_name): return functools.reduce(getattr, module_name.split("."), model) # Dataset def get_laion_dataset( tokenizer, resolution=512, path="bhargavsdesai/laion_improved_aesthetics_6.5plus_with_images", ): dataset = load_dataset(path, split="train") with open( "./assets/laion_improved_aesthetics_6.5plus_with_images_blip_captions.json" ) as read_file: all_captions = json.load(read_file) def get_blip_caption(example, idx): captions = [all_captions[item] for item in idx] example["input_ids"] = tokenizer( captions, padding="max_length", truncation=True, max_length=tokenizer.model_max_length, return_tensors="pt", ).input_ids return example dataset = dataset.map(get_blip_caption, with_indices=True, batched=True) def process(image): img = np.array(image) img = cv2.resize(img, (resolution, resolution), interpolation=cv2.INTER_CUBIC) img = np.array(img).astype(np.float32) img = img / 127.5 - 1.0 return torch.from_numpy(img).permute(2, 0, 1) def transform(example): batch = {} images = [ Image.open(io.BytesIO(item["bytes"])).convert("RGB") for item in example["image"] ] batch["image"] = torch.stack([process(image) for image in images], dim=0) batch["text_input_ids"] = torch.from_numpy( np.array(example["input_ids"]) ).long() return batch dataset.set_transform(transform) return dataset def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( "--pretrained_model_name_or_path", type=str, default=None, required=True, help="Path to pretrained model or model identifier from huggingface.co/models.", ) parser.add_argument( "--pretrained_linfusion_path", type=str, default=None, help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.", ) parser.add_argument( "--data_json_file", type=str, default=None, help="Training data", ) parser.add_argument( "--data_root_path", type=str, default=None, help="Training data root path", ) parser.add_argument( "--output_dir", type=str, default="linear_attn_tune_attn", help="The output directory where the model predictions and checkpoints will be written.", ) parser.add_argument( "--logging_dir", type=str, default="logs", help=( "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." ), ) parser.add_argument( "--resolution", type=int, default=512, help=("The resolution for input images"), ) parser.add_argument( "--learning_rate", type=float, default=1e-4, help="Learning rate to use.", ) parser.add_argument( "--weight_decay", type=float, default=0, help="Weight decay to use." ) parser.add_argument("--num_train_epochs", type=int, default=300) parser.add_argument( "--train_batch_size", type=int, default=6, help="Batch size (per device) for the training dataloader.", ) parser.add_argument( "--gradient_accumulation_steps", type=int, default=2, help="Gradient accumulation steps. Total bs=train_batch_size * gradient_accumulation_steps * num_gpus", ) parser.add_argument( "--dataloader_num_workers", type=int, default=8, help=( "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." ), ) parser.add_argument( "--save_steps", type=int, default=10000, help=("Save a checkpoint of the training state every X updates"), ) parser.add_argument( "--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"], help=( "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." ), ) parser.add_argument( "--report_to", type=str, default="tensorboard", help=( 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' ), ) parser.add_argument( "--local_rank", type=int, default=-1, help="For distributed training: local_rank", ) parser.add_argument( "--mid_dim_scale", type=int, default=None, help="The scale of the mid_dim of the linear attention. `mid_dim = dim_n // mid_dim_scale`", ) args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: args.local_rank = env_local_rank return args def main(): args = parse_args() logging_dir = Path(args.output_dir, args.logging_dir) accelerator_project_config = ProjectConfiguration( project_dir=args.output_dir, logging_dir=logging_dir ) accelerator = Accelerator( mixed_precision=args.mixed_precision, log_with=args.report_to, project_config=accelerator_project_config, gradient_accumulation_steps=args.gradient_accumulation_steps, ) if accelerator.is_main_process: if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) # Load scheduler, tokenizer and models. noise_scheduler = DDPMScheduler.from_pretrained( args.pretrained_model_name_or_path, subfolder="scheduler" ) tokenizer = CLIPTokenizer.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer" ) text_encoder = CLIPTextModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder" ) vae = AutoencoderKL.from_pretrained( args.pretrained_model_name_or_path, subfolder="vae" ) unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet" ) unet_teacher = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet" ) # freeze parameters of models to save more memory unet.requires_grad_(False) unet_teacher.requires_grad_(False) vae.requires_grad_(False) text_encoder.requires_grad_(False) all_attn_outputs = [] all_attn_outputs_teacher = [] if args.pretrained_linfusion_path is not None: # to construct a LinFusion model linfusion_model = LinFusion.construct_for( unet=unet, load_pretrained=True, pretrained_model_name_or_path=args.pretrained_linfusion_path, ) else: linfusion_config = LinFusion.get_default_config(unet=unet) if args.mid_dim_scale is not None: for each in linfusion_config["modules_list"]: each["projection_mid_dim"] = each["dim_n"] // args.mid_dim_scale linfusion_model = LinFusion(**linfusion_config) linfusion_model.mount_to(unet=unet) def student_forward_hook(module, input, output): all_attn_outputs.append(output) def teacher_forward_hook(module, input, output): all_attn_outputs_teacher.append(output) for sub_module in linfusion_model.modules_list: sub_module_name = sub_module["module_name"] student_module = get_submodule(unet, sub_module_name) teacher_module = get_submodule(unet_teacher, sub_module_name) student_module.register_forward_hook(student_forward_hook) teacher_module.register_forward_hook(teacher_forward_hook) weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 unet_teacher.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype) text_encoder.to(accelerator.device, dtype=weight_dtype) linfusion_model.requires_grad_(True) # optimizer optimizer = torch.optim.AdamW( linfusion_model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay, ) # dataloader train_dataset = get_laion_dataset(tokenizer=tokenizer, resolution=args.resolution) train_dataloader = torch.utils.data.DataLoader( train_dataset, shuffle=True, batch_size=args.train_batch_size, num_workers=args.dataloader_num_workers, ) # Prepare everything with our `accelerator`. linfusion_model, unet, optimizer, train_dataloader = accelerator.prepare( linfusion_model, unet, optimizer, train_dataloader ) global_step = 0 for epoch in range(0, args.num_train_epochs): begin = time.perf_counter() for step, batch in enumerate(train_dataloader): load_data_time = time.perf_counter() - begin with accelerator.accumulate(linfusion_model): # Convert images to latent space with torch.no_grad(): latents = vae.encode( batch["image"].to(accelerator.device, dtype=weight_dtype) ).latent_dist.sample() latents = latents * vae.config.scaling_factor # Sample noise that we'll add to the latents noise = torch.randn_like(latents) bsz = latents.shape[0] # Sample a random timestep for each image timesteps = torch.randint( 0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device, ) timesteps = timesteps.long() # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) with torch.no_grad(): encoder_hidden_states = text_encoder( batch["text_input_ids"].to(accelerator.device) )[0] noise_pred_teacher = unet_teacher( noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states, ).sample noise_pred = unet( noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states, ).sample if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": target = noise_scheduler.get_velocity(latents, noise, timesteps) else: raise ValueError( f"Unknown prediction type {noise_scheduler.config.prediction_type}" ) loss_noise = F.mse_loss( noise_pred.float(), target.float(), reduction="mean" ) loss_kd = F.mse_loss( noise_pred.float(), noise_pred_teacher.float(), reduction="mean" ) loss_feat = sum( [ F.mse_loss(feat.float(), feat_teacher.float()) for feat, feat_teacher in zip( all_attn_outputs, all_attn_outputs_teacher ) ] ) / len(all_attn_outputs_teacher) loss = loss_noise + loss_kd * 0.5 + loss_feat * 0.5 # Gather the losses across all processes for logging (if we use distributed training). avg_loss = ( accelerator.gather(loss.repeat(args.train_batch_size)).mean().item() ) avg_loss_kd = ( accelerator.gather(loss_kd.repeat(args.train_batch_size)) .mean() .item() ) avg_loss_feat = ( accelerator.gather(loss_feat.repeat(args.train_batch_size)) .mean() .item() ) avg_loss_noise = ( accelerator.gather(loss_noise.repeat(args.train_batch_size)) .mean() .item() ) # Backpropagate accelerator.backward(loss) optimizer.step() optimizer.zero_grad() all_attn_outputs.clear() all_attn_outputs_teacher.clear() if accelerator.is_main_process: print( "Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}, step_loss_noise: {}, step_loss_kd: {}, step_loss_feat: {}".format( epoch, step, load_data_time, time.perf_counter() - begin, avg_loss, avg_loss_noise, avg_loss_kd, avg_loss_feat, ) ) global_step += 1 if global_step % args.save_steps == 0 and accelerator.is_main_process: # Save model checkpoint linfusion_model.save_pretrained( os.path.join(args.output_dir, f"linfusion-{global_step}"), push_to_hub=False, ) begin = time.perf_counter() if __name__ == "__main__": main()