from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from transformers import AutoProcessor, CLIPVisionModelWithProjection
from diffusers.optimization import get_scheduler
from diffusers.models import AutoencoderKL, UNet2DConditionModel, UNet2DModel
from diffusers import UniPCMultistepScheduler, PNDMScheduler
from pathlib import Path
from safetensors.torch import save_file
import torch.nn.functional as F
import lightning as L
import torch.nn as nn
import torch
import sys

PROJECT_DIR = Path(__file__).resolve().parent.parent.parent
sys.path.append(str(PROJECT_DIR))

from ootd.pipelines_ootd.unet_vton_2d_condition import UNetVton2DConditionModel
from ootd.pipelines_ootd.unet_garm_2d_condition import UNetGarm2DConditionModel


class OOTDiffusion(L.LightningModule):
    def __init__(self,
                 vae_path,
                 unet_path,
                 model_path,
                 vit_path,
                 scheduler_path,
                 mtype,
                 batch_size,
                 max_length,
                 lr,
                 lr_scheduler,
                 beta1: float = 0.9,
                 beta2: float = 0.99,
                 weight_decay: float = 5e-8,
                 eps: float = 1e-8,
                 num_warmup_steps: int = 500,
                 num_training_steps: int = 10000,
                 num_cycles: int = 1,
                 power: int = 1,
                 conditioning_dropout_prob: float = 0.1
                 ):
        super().__init__()
        self.vae_path = vae_path
        self.unet_path = unet_path
        self.scheduler_path = scheduler_path
        self.model_path = model_path
        self.vit_path = vit_path
        self.mtype = mtype
        self.batch_size = batch_size
        self.max_length = max_length
        self.conditioning_dropout_prob = conditioning_dropout_prob
        self.lr = lr
        self.lr_scheduler = lr_scheduler
        self.beta1, self.beta2 = beta1, beta2
        self.weight_decay = weight_decay
        self.eps = eps
        self.num_warmup_steps = num_warmup_steps
        self.num_training_steps = num_training_steps
        self.num_cycles = num_cycles
        self.power = power
        self.init_models()
        self.automatic_optimization = False

    def init_models(self):
        self.vae = AutoencoderKL.from_pretrained(
            self.vae_path,
            subfolder="vae",
            torch_dtype=torch.float32,  # frozen, not trained
        )
        self.unet_garm = UNetGarm2DConditionModel.from_pretrained(
            self.unet_path,
            subfolder="unet_garm",
            torch_dtype=torch.float32,
            use_safetensors=True
        )
        self.unet_vton = UNetVton2DConditionModel.from_pretrained(
            self.unet_path,
            subfolder="unet_vton",
            torch_dtype=torch.float32,
            use_safetensors=True
        )
        # Expand the vton UNet's input conv from 4 to 8 channels so it can take the
        # noisy latents concatenated with the agnostic-person latents; the new weights
        # for the extra channels are zero-initialized, the original ones are copied over.
        if self.unet_vton.conv_in.in_channels == 4:
            with torch.no_grad():
                new_in_channels = 8
                conv_new = nn.Conv2d(
                    in_channels=new_in_channels,
                    out_channels=self.unet_vton.conv_in.out_channels,
                    kernel_size=3,
                    padding=1
                )
                conv_new.weight.data.fill_(0)
                conv_new.weight.data[:, :4] = self.unet_vton.conv_in.weight.data
                conv_new.bias.data = self.unet_vton.conv_in.bias.data
                self.unet_vton.conv_in = conv_new

        self.tokenizer = CLIPTokenizer.from_pretrained(
            self.model_path, subfolder="tokenizer"
        )
        self.text_encoder = CLIPTextModel.from_pretrained(
            self.model_path, subfolder="text_encoder"
        )
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            self.vit_path
        )
        self.auto_processor = AutoProcessor.from_pretrained(self.vit_path)

        # Only the two UNets are trained; the VAE and the text/image encoders stay frozen.
        self.unet_garm.requires_grad_(True)
        self.unet_vton.requires_grad_(True)
        self.vae.requires_grad_(False).eval()
        self.image_encoder.requires_grad_(False).eval()
        self.text_encoder.requires_grad_(False).eval()

        vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
        self.noise_scheduler = PNDMScheduler.from_config(self.scheduler_path)

        self.vae.to(self.device)
        self.unet_garm.to(self.device)
        self.unet_vton.to(self.device)
        self.image_encoder.to(self.device)
        self.text_encoder.to(self.device)

    def configure_optimizers(self):
        params_to_optimize = list(self.unet_garm.parameters()) + list(self.unet_vton.parameters())
        optimizer = torch.optim.AdamW(
            params_to_optimize,
            lr=self.lr,
            betas=(self.beta1, self.beta2),
            weight_decay=self.weight_decay,
            eps=self.eps
        )
        lr_scheduler = get_scheduler(
            name=self.lr_scheduler,
            optimizer=optimizer,
            num_warmup_steps=self.num_warmup_steps,
            num_training_steps=self.num_training_steps,
            num_cycles=self.num_cycles,
            power=self.power
        )
        return [optimizer], [lr_scheduler]

    def tokenize_captions(self, captions, max_length):
        inputs = self.tokenizer(
            captions,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        return inputs.input_ids

    def forward(self, x):
        image_garm = x['cloth']['paired'].to(self.device)
        image_vton = x['img_agnostic'].to(self.device)
        image_ori = x['img'].to(self.device)

        with torch.no_grad():
            # Get garment prompt embeddings from the CLIP image encoder.
            prompt_image = self.auto_processor(images=image_garm, return_tensors="pt").to(self.device)
            prompt_image = self.image_encoder(prompt_image.data['pixel_values']).image_embeds
            prompt_image = prompt_image.unsqueeze(1)
            if self.mtype == "hd":
                prompt_embeds = self.text_encoder(self.tokenize_captions([''] * self.batch_size, 2).to(self.device))[0]
                prompt_embeds[:, 1:] = prompt_image[:]
            elif self.mtype == "dc":
                # prompt_embeds = self.text_encoder(self.tokenize_captions([category], 3))
                raise NotImplementedError
            else:
                raise NotImplementedError

            # [0, 1] -> [-1, 1]
            image_garm = self.image_processor.preprocess(image_garm)
            image_vton = self.image_processor.preprocess(image_vton)
            image_ori = self.image_processor.preprocess(image_ori)

            # Convert the target image to latent space.
            latents = self.vae.encode(image_ori).latent_dist.sample()
            latents = latents * self.vae.config.scaling_factor

            # Sample the noise to add to the latents.
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            # Sample a random timestep for each image.
            timesteps = torch.randint(0, self.noise_scheduler.config.num_train_timesteps, (bsz,)).to(self.device)
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (forward diffusion process).
            noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)

            # Encode the input prompt.
            prompt_embeds = prompt_embeds.to(self.device)
            bs_embed, seq_len, _ = prompt_embeds.shape
            # Duplicate the embeddings for each generation per prompt, using an mps-friendly method.
            num_images_per_prompt = 1
            prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
            prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

            # Prepare the image latents.
            image_latents_garm = self.vae.encode(image_garm).latent_dist.mode()
            image_latents_garm = torch.cat([image_latents_garm], dim=0)
            image_latents_vton = self.vae.encode(image_vton).latent_dist.mode()
            image_latents_vton = torch.cat([image_latents_vton], dim=0)

        # Randomly drop the garment condition for part of the batch
        # (classifier-free-guidance-style conditioning dropout).
        if self.conditioning_dropout_prob is not None:
            random_p = torch.rand(bsz).to(self.device)
            # Sample masks for the garment images.
            image_mask_dtype = image_latents_garm.dtype
            image_mask = 1 - (
                (random_p >= self.conditioning_dropout_prob).to(image_mask_dtype)
                * (random_p < 3 * self.conditioning_dropout_prob).to(image_mask_dtype)
            )
            image_mask = image_mask.reshape(bsz, 1, 1, 1)
            image_latents_garm = image_mask * image_latents_garm

        sample, spatial_attn_outputs = self.unet_garm(
            image_latents_garm,
            0,
            encoder_hidden_states=prompt_embeds,
            return_dict=False
        )

        latent_vton_model_input = torch.cat([noisy_latents, image_latents_vton], dim=1)
        spatial_attn_inputs = spatial_attn_outputs.copy()
        noise_pred = self.unet_vton(
            latent_vton_model_input,
            spatial_attn_inputs,
            timesteps,
            encoder_hidden_states=prompt_embeds,
            return_dict=False
        )[0]

        # Zero-weighted term: keeps the garment UNet's primary output connected to the
        # loss graph without changing the loss value.
        util_adv_loss = torch.nn.functional.softplus(-sample[0]).mean() * 0
        loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") + util_adv_loss
        return loss

    def training_step(self, batch):
        opt = self.optimizers()
        sch = self.lr_schedulers()
        loss = self(batch)
        self.manual_backward(loss)
        self.log('loss', loss, prog_bar=True)
        opt.step()
        sch.step()
        opt.zero_grad()
        return loss

    def on_save_checkpoint(self, checkpoint):
        save_file(self.unet_garm.state_dict(), "checkpoints/unet_garm/diffusion_pytorch_model.safetensors")
        save_file(self.unet_vton.state_dict(), "checkpoints/unet_vton/diffusion_pytorch_model.safetensors")


if __name__ == "__main__":
    model = OOTDiffusion(
        vae_path="/home/modelzoo/OOTDiffusion/checkpoints/ootd",
        unet_path="/home/modelzoo/OOTDiffusion/checkpoints/ootd/ootd_dc/checkpoint-36000",
        model_path="/home/modelzoo/OOTDiffusion/checkpoints/ootd",
        vit_path="/home/modelzoo/OOTDiffusion/checkpoints/clip-vit-large-patch14",
        scheduler_path="/home/modelzoo/OOTDiffusion/checkpoints/ootd/scheduler",
        mtype="hd",
        batch_size=1,
        max_length=128,
        lr=1e-5,                              # placeholder values: lr and lr_scheduler are
        lr_scheduler="constant_with_warmup"   # required arguments; tune as needed
    )
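
    # Minimal, hypothetical launch sketch (not part of the original script). It assumes a
    # DataLoader whose batches match the dict layout expected by forward():
    # {'cloth': {'paired': Tensor}, 'img_agnostic': Tensor, 'img': Tensor}.
    # The dataset object below is a placeholder; substitute your own.
    #
    # from torch.utils.data import DataLoader
    # train_loader = DataLoader(my_vton_dataset, batch_size=1, shuffle=True, num_workers=4)
    # trainer = L.Trainer(max_steps=model.num_training_steps, accelerator="gpu", devices=1)
    # trainer.fit(model, train_dataloaders=train_loader)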