print("---------------- Train_2_Priortrainer.py --------------------") import torch from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, OpenAIClipAdapter, DiffusionPriorTrainer from dalle2_pytorch.tokenizer import SimpleTokenizer import torchvision.transforms as T from PIL import Image import pickle import os import torch.utils.data as data from torch.nn.utils.rnn import pad_sequence from datasets import load_dataset, concatenate_datasets from accelerate import Accelerator import pdb device = torch.device("cuda") weight_dir = "./Priortrainer_weight_log" os.makedirs(weight_dir, exist_ok=True) num_epochs = 3 #batch_idx = 196491 checkp_interval = 35000 batch_size = 2 clip = OpenAIClipAdapter() def xosc2ImageDataset(): with open('../Dataset_dictionary.pkl', 'rb') as f: loaded_dict = pickle.load(f) dset = loaded_dict return dset class TextDataset: def __init__(self, texts, batch_size=4, max_length=4500): self.texts = texts self.batch_size = batch_size self.max_length = max_length def __len__(self): return len(self.texts) def __iter__(self): self.current_index = 0 # Setze den Index zu Beginn der Iteration zurück return self def __next__(self): batch_texts = [] for _ in range(self.batch_size): if self.current_index >= len(self.texts): raise StopIteration text = self.texts[self.current_index] #print(text[:90]) text = text[:self.max_length] tensor = torch.tensor([ord(char) for char in text]) batch_texts.append(tensor) self.current_index += 1 padded_tensors = pad_sequence(batch_texts, batch_first=True) return padded_tensors class ImageDataset: def __init__(self, image_paths, batch_size=4, image_size=(256, 256)): self.image_paths = image_paths self.batch_size = batch_size self.image_size = image_size def __len__(self): return len(self.image_paths) def __iter__(self): self.current_index = 0 # Setze den Index zu Beginn der Iteration zurück return self def __next__(self): batch_images = [] for _ in range(self.batch_size): if self.current_index >= len(self.image_paths): raise StopIteration path = self.image_paths[self.current_index] normalized_path = path.replace('\\', '/') #print(normalized_path) image = self.load_image(normalized_path) batch_images.append(image) self.current_index += 1 return torch.stack(batch_images, dim=0) def load_image(self, path): transform = T.Compose([ T.Resize(self.image_size), T.ToTensor(), ]) image = Image.open(path).convert("RGB") image = transform(image) return image data = load_dataset('json',data_files='data.json') image_list = data[f"train"]['image_path'] # list of captions text_list = data[f"train"]['text'] text_dataset = TextDataset(text_list, batch_size=batch_size) image_dataset = ImageDataset(image_list, batch_size=batch_size) """prior networks (with transformer)""" #setup prior network, which contains an autoregressive transformer prior_network = DiffusionPriorNetwork( dim = 512, depth = 6, dim_head = 64, heads = 8 ).cuda() diffusion_prior = DiffusionPrior(# diffusion prior network, which contains the CLIP and network (with transformer) above net = prior_network, clip = clip, timesteps = 1000, sample_timesteps = 64, cond_drop_prob = 0.2 ).cuda() accelerator = Accelerator() prior_trainer = DiffusionPriorTrainer( diffusion_prior, accelerator=accelerator, lr = 3e-4, ) if os.listdir(weight_dir): # Load the last checkpoint with the highest epoch number last_epoch = max([int(name.split('_')[-2]) for name in os.listdir(weight_dir)]) last_batch_idx = max([int(name.split('_')[-1].split('.')[0]) for name in os.listdir(weight_dir) if 

text_dataset = TextDataset(text_list, batch_size=batch_size)
image_dataset = ImageDataset(image_list, batch_size=batch_size)

# Prior network (an autoregressive transformer) ...
prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
).cuda()

# ... wrapped by the diffusion prior, which holds CLIP and the network above.
diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 1000,
    sample_timesteps = 64,
    cond_drop_prob = 0.2
).cuda()

accelerator = Accelerator()
prior_trainer = DiffusionPriorTrainer(
    diffusion_prior,
    accelerator=accelerator,
    lr = 3e-4,
)

if os.listdir(weight_dir):
    # Resume from the checkpoint with the highest epoch number (and, within
    # that epoch, the highest batch index); checkpoints are named
    # model_prior_{epoch}_{idx}.pt. The trailing underscore in the prefix
    # check keeps e.g. epoch 1 from matching epoch 10.
    last_epoch = max(int(name.split('_')[-2]) for name in os.listdir(weight_dir))
    last_batch_idx = max(int(name.split('_')[-1].split('.')[0])
                         for name in os.listdir(weight_dir)
                         if name.startswith(f'model_prior_{last_epoch}_'))
    checkpoint_path = os.path.join(weight_dir, f'model_prior_{last_epoch}_{last_batch_idx}.pt')
    prior_trainer.load(checkpoint_path, overwrite_lr=True, strict=True)
    start_epoch = last_epoch + 1  # continue with the epoch after the last saved one
    print("Checkpoint loaded")
else:
    start_epoch = 0
    print("starting from zero")

# checkpoint_path = './model/prior.pth'
# prior_trainer.load(checkpoint_path, overwrite_lr = True, strict=True)

t = SimpleTokenizer()  # currently unused during training; kept for the sampling sketch below
num_batches = len(text_dataset) // batch_size  # total number of full batches
print("Number of batches:", num_batches, "Length of dataset:", len(text_dataset))
save_every = max(1, num_batches // 10)  # save roughly ten checkpoints per epoch

for epoch in range(num_epochs):
    ep = epoch + start_epoch
    # Create the iterators once per epoch; creating them inside the batch loop
    # would restart them every step and train on the first batch only.
    text_loader = iter(text_dataset)
    image_loader = iter(image_dataset)
    for idx in range(num_batches):
        batch_texts = next(text_loader)
        batch_images = next(image_loader)
        loss = prior_trainer(
            batch_texts.to(device),
            batch_images.to(device)
        )
        prior_trainer.update()  # update the model parameters with the optimizer
        if idx % 10 == 0:
            print(f"epoch {ep}, step {idx}, loss {loss}")
        if idx % save_every == 0:
            # Periodically save the model.
            prior_trainer.save(os.path.join(weight_dir, f'model_prior_{ep}_{idx}.pt'))
# do above for many steps ...
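
# --- Sampling sketch (illustrative only, not executed during training) -------
# A minimal sketch of turning a caption into an image embedding with the
# trained prior. It assumes DiffusionPrior.sample() and
# SimpleTokenizer.tokenize() keep the interfaces documented in the
# dalle2_pytorch README; the caption below is hypothetical.
#
#   tokens = t.tokenize(['ego vehicle overtaking on a rural road'], truncate_text=True).to(device)
#   image_embed = diffusion_prior.sample(tokens)
#   print(image_embed.shape)  # (1, 512) for the dim = 512 prior above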