print("---------------- Train_2.py --------------------") import torch from dalle2_pytorch.tokenizer import SimpleTokenizer from dalle2_pytorch import OpenAIClipAdapter, Unet, Decoder, DecoderTrainer import torchvision.transforms as T from torchvision.utils import save_image from PIL import Image import os import torch.utils.data as data import pickle from datasets import load_dataset, concatenate_datasets """ got the base fot this here: https://github.com/lucidrains/DALLE2-pytorch/issues/279""" import pdb import time # Parameters image_size = 256 # Image dimension batch_size = 2 # Batch size for training, adjust based on GPU memory learning_rate = 1e-4 # Learning rate for the optimizer num_epochs = 3 # Number of epochs for training log_image_interval = 1000 # Interval for logging images checkp_interval = 2500 log_idx = 100 save_dir = "./T2_log_images" # Directory to save log images weight_dir = "./T2_weight_log" os.makedirs(save_dir, exist_ok=True) # Create save directory if it doesn't exist os.makedirs(weight_dir, exist_ok=True) def xosc2ImageDataset(): with open('../Dataset_dictionary.pkl', 'rb') as f: loaded_dict = pickle.load(f) dset = loaded_dict return dset class ImgTextDataset(data.Dataset): def __init__(self, data): self.img_paths = data[f"train"]['image_path'] self.captions = data[f"train"]['text'] # Apply required image transforms. For my model I needed images with 256 x 256 dimensions. self.image_transform = T.Compose([ T.Resize((256, 256)), T.ToTensor() ]) def __len__(self): return len(self.img_paths) def __getitem__(self, idx): image_path = self.img_paths[idx] caption = self.captions[idx] image = Image.open(image_path) image_pt = self.image_transform(image) if image_pt.shape[0]==1: image_pt = image_pt.repeat(3,1,1) return image_pt, caption # Setup device device = torch.device("cuda") # Define your image-text dataset data = load_dataset('json',data_files='data.json') dataset = ImgTextDataset(data) num_workers = 0 dataloader = torch.utils.data.DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, shuffle=True) # Initialize OpenAI CLIP model adapter clip = OpenAIClipAdapter() # Create models for training unet1 = Unet( dim=128, image_embed_dim=512, text_embed_dim=512, cond_dim=128, channels=3, dim_mults=(1, 2, 4, 8), cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings ).cuda() unet2 = Unet( dim = 16, image_embed_dim = 512, cond_dim = 128, channels = 3, dim_mults = (1, 2, 4, 8, 16) ).cuda() decoder = Decoder( unet = (unet1, unet2), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here) image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in) clip = clip, timesteps = 100 ).cuda() decoder_trainer = DecoderTrainer( decoder, lr=3e-4, wd=1e-2, ema_beta=0.99, ema_update_after_step=1000, ema_update_every=10, ).cuda() # Use built-in tokenizer. You can use others like GPT2, YTTM etc. t = SimpleTokenizer() # Training loop. # Iterate over the dataloader and pass image tensors and tokenized text to the training wrapper. # Repeat process N times. 
# Initialize OpenAI CLIP model adapter
clip = OpenAIClipAdapter()

# Create models for training
unet1 = Unet(
    dim=128,
    image_embed_dim=512,
    text_embed_dim=512,
    cond_dim=128,
    channels=3,
    dim_mults=(1, 2, 4, 8),
    cond_on_text_encodings=True  # set to True for any unets that need to be conditioned on text encodings
).cuda()

unet2 = Unet(
    dim=16,
    image_embed_dim=512,
    cond_dim=128,
    channels=3,
    dim_mults=(1, 2, 4, 8, 16)
).cuda()

decoder = Decoder(
    unet=(unet1, unet2),     # insert both unets, ordered from lowest to highest resolution (you can have as many stages as you want)
    image_sizes=(256, 512),  # resolutions: 256 for the first unet, 512 for the second; must be unique and in ascending order (matching the unets passed in)
    clip=clip,
    timesteps=100
).cuda()

decoder_trainer = DecoderTrainer(
    decoder,
    lr=3e-4,
    wd=1e-2,
    ema_beta=0.99,
    ema_update_after_step=1000,
    ema_update_every=10,
).cuda()

# Use the built-in tokenizer. You can use others like GPT-2, YTTM, etc.
t = SimpleTokenizer()

# Training loop:
# iterate over the dataloader, pass image tensors and tokenized text to the
# training wrapper, and repeat for num_epochs epochs.
if os.listdir(weight_dir):
    # Load the latest checkpoint with the highest epoch number
    last_epoch = max([int(name.split('_')[-2]) for name in os.listdir(weight_dir)])
    last_batch_idx = max([int(name.split('_')[-1].split('.')[0]) for name in os.listdir(weight_dir) if name.startswith(f'model_decoder_{last_epoch}')])
    checkpoint_path = os.path.join(weight_dir, f'model_decoder_{last_epoch}_{last_batch_idx}.pt')
    print("Loading from: ", checkpoint_path)
    checkpoint = torch.load(checkpoint_path)
    decoder_trainer.load_state_dict(checkpoint)
    start_epoch = last_epoch + 1  # resume with the epoch after the loaded one
    print("Checkpoint loaded")
else:
    start_epoch = 0  # start from the first epoch
    print("Starting from epoch 0")

for epoch in range(num_epochs):
    ep = epoch + start_epoch  # absolute epoch number, offset by the resumed checkpoint
    for batch_idx, (images, texts) in enumerate(dataloader):
        images_copy = images.clone().detach()
        text = t.tokenize(texts, context_length=1024)
        text_copy = text.clone().detach()
        t0 = time.time()
        # with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], record_shapes=True, profile_memory=False, with_stack=False) as prof:
        loss = decoder_trainer(
            images_copy.cuda(),
            text=text_copy.cuda(),
            unet_number=1,
            max_batch_size=4
        )
        # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
        # Export a more detailed profiling trace as JSON:
        # prof.export_chrome_trace('./model_prof_nv.json')
        # pdb.set_trace()
        # print(time.time() - t0)
        decoder_trainer.update(1)

        if batch_idx % 10 == 0:
            print(f"epoch {ep}, step {batch_idx}, loss {loss}")

        if batch_idx % log_image_interval == 0 and batch_idx != 0:
            image_embed = clip.embed_image(images.cuda())
            sample = decoder_trainer.sample(image_embed=image_embed[0], text=t.tokenize(texts).cuda())
            save_image(sample, os.path.join(save_dir, f'{ep}_{batch_idx}.png'))

        if batch_idx % checkp_interval == 0:
            # Periodically save the model.
            decoder_trainer.save(os.path.join(weight_dir, f'model_decoder_{ep}_{batch_idx}.pt'))
            # torch.save(decoder_trainer.state_dict(), f'./T2_weight_log/model_decoder_{ep}_{batch_idx}.pt')
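# --- Sketch: stepping the second unet ---------------------------------------
# The loop above only ever trains unet_number=1, so unet2 (the 256 -> 512
# upsampler declared in image_sizes) never receives gradients. Below is a
# minimal sketch of the extra step; it is not called anywhere and assumes a
# dataloader that yields 512x512 ground-truth images (the current
# ImgTextDataset resizes everything to 256x256). unet2 was not built with
# cond_on_text_encodings=True, so no text is passed for it.
def train_unet2_step(images_512):
    loss2 = decoder_trainer(
        images_512.cuda(),  # hypothetical 512x512 image batch
        unet_number=2,
        max_batch_size=4
    )
    decoder_trainer.update(2)  # apply optimizer/EMA updates for unet 2
    return loss2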