# train_decoder.py
print("---------------- Train_2.py --------------------")
import torch
from dalle2_pytorch.tokenizer import SimpleTokenizer
from dalle2_pytorch import OpenAIClipAdapter, Unet, Decoder, DecoderTrainer
import torchvision.transforms as T
from torchvision.utils import save_image
from PIL import Image
import os
import torch.utils.data as data
import pickle
from datasets import load_dataset, concatenate_datasets
""" got the base fot this here: https://github.com/lucidrains/DALLE2-pytorch/issues/279"""
import pdb
import time

# Parameters
image_size = 256            # Input image dimension (images are resized to image_size x image_size)
batch_size = 2              # Training batch size; adjust based on GPU memory
learning_rate = 3e-4        # Learning rate used by the DecoderTrainer below
num_epochs = 3              # Number of training epochs
log_image_interval = 1000   # Log a sampled image every N batches
checkp_interval = 2500      # Save a checkpoint every N batches
log_idx = 100               # currently unused
save_dir = "./T2_log_images"    # Directory for logged sample images
weight_dir = "./T2_weight_log"  # Directory for checkpoints
os.makedirs(save_dir, exist_ok=True)
os.makedirs(weight_dir, exist_ok=True)

def xosc2ImageDataset():
    # Alternative loader: reads a pre-built dataset dictionary from disk.
    # Currently unused; the training below loads from data.json instead.
    with open('../Dataset_dictionary.pkl', 'rb') as f:
        loaded_dict = pickle.load(f)
    return loaded_dict


class ImgTextDataset(data.Dataset):
    def __init__(self, hf_dataset):
        self.img_paths = hf_dataset["train"]['image_path']
        self.captions = hf_dataset["train"]['text']
        # Resize to the model's input resolution and convert to a [0, 1] tensor.
        self.image_transform = T.Compose([
            T.Resize((image_size, image_size)),
            T.ToTensor()
        ])
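        # Note: ToTensor scales pixels to [0, 1]; dalle2-pytorch's Decoder is
        # expected to handle any further normalization to its internal range itself.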

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        image_path = self.img_paths[idx]
        caption = self.captions[idx]

        # Convert to RGB so grayscale, palette, and RGBA images all yield 3 channels.
        image = Image.open(image_path).convert('RGB')
        image_pt = self.image_transform(image)
        return image_pt, caption

# Setup device (the script assumes CUDA is available; all modules below use .cuda())
device = torch.device("cuda")

# Define the image-text dataset. Avoid naming the result `data`, which would
# shadow the torch.utils.data module imported above.
raw_data = load_dataset('json', data_files='data.json')
dataset = ImgTextDataset(raw_data)
num_workers = 0
dataloader = torch.utils.data.DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, shuffle=True)
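# Each batch yields (images, texts): a float tensor of shape
# (batch_size, 3, image_size, image_size) and a tuple of caption strings.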

# Initialize OpenAI CLIP model adapter
clip = OpenAIClipAdapter()
# Create models for training
unet1 = Unet(
    dim=128,
    image_embed_dim=512,
    text_embed_dim=512,
    cond_dim=128,
    channels=3,
    dim_mults=(1, 2, 4, 8),
    cond_on_text_encodings=True  # required for any unet conditioned on text encodings
).cuda()

unet2 = Unet(
    dim=16,
    image_embed_dim=512,
    cond_dim=128,
    channels=3,
    dim_mults=(1, 2, 4, 8, 16)
).cuda()


decoder = Decoder(
    unet=(unet1, unet2),     # unets ordered from lowest to highest resolution (any number of stages works)
    image_sizes=(256, 512),  # target resolution per unet; must be unique and ascending
    clip=clip,
    timesteps=100
).cuda()

decoder_trainer = DecoderTrainer(
    decoder,
    lr=learning_rate,
    wd=1e-2,
    ema_beta=0.99,
    ema_update_after_step=1000,
    ema_update_every=10,
).cuda()
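# DecoderTrainer wraps the decoder with an optimizer and exponential moving
# average (EMA) copies of the unets: EMA updates begin after
# ema_update_after_step optimizer steps and then run every ema_update_every
# steps; sampling below uses the EMA weights.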

# Use the built-in SimpleTokenizer; other tokenizers (GPT-2, YTTM, etc.) work as well.
t = SimpleTokenizer()

# Training loop: iterate over the dataloader, passing image tensors and
# tokenized captions to the trainer wrapper, for num_epochs epochs.

if os.listdir(weight_dir):
    # Resume from the newest checkpoint (highest epoch, then highest batch index).
    last_epoch = max(int(name.split('_')[-2]) for name in os.listdir(weight_dir))
    last_batch_idx = max(int(name.split('_')[-1].split('.')[0]) for name in os.listdir(weight_dir) if name.startswith(f'model_decoder_{last_epoch}_'))
    checkpoint_path = os.path.join(weight_dir, f'model_decoder_{last_epoch}_{last_batch_idx}.pt')
    print("Loading from:", checkpoint_path)
    # DecoderTrainer.save writes the full trainer state, so restore it with .load
    decoder_trainer.load(checkpoint_path)
    start_epoch = last_epoch + 1  # resume from the epoch after the loaded one
    print("Checkpoint loaded")
else:
    start_epoch = 0  # start from the first epoch
    print("Starting from scratch")


for epoch in range(start_epoch, start_epoch + num_epochs):
    for batch_idx, (images, texts) in enumerate(dataloader):
        text = t.tokenize(texts, context_length=1024)
        t0 = time.time()
        # Optional profiling (uncomment to inspect per-op CPU/CUDA time):
        # with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], record_shapes=True) as prof:
        loss = decoder_trainer(
            images.cuda(),
            text=text.cuda(),
            unet_number=1,      # only the first (256 px base) unet is trained here
            max_batch_size=4
        )
        # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
        # prof.export_chrome_trace('./model_prof_nv.json')  # export a detailed profiling trace as JSON
        # pdb.set_trace()
        # print(time.time() - t0)
        decoder_trainer.update(1)   # optimizer + EMA step for unet 1

        if batch_idx % 10 == 0:
            print(f"epoch {epoch}, step {batch_idx}, loss {loss}")
        if batch_idx % log_image_interval == 0 and batch_idx != 0:
            # embed_image returns (image_embed, image_encodings); sample from the embedding
            image_embed = clip.embed_image(images.cuda())
            sample = decoder_trainer.sample(image_embed=image_embed[0], text=t.tokenize(texts).cuda())
            save_image(sample, os.path.join(save_dir, f'{epoch}_{batch_idx}.png'))
        if batch_idx % checkp_interval == 0:
            # Periodically save the full trainer state (model, optimizer, EMA).
            decoder_trainer.save(os.path.join(weight_dir, f'model_decoder_{epoch}_{batch_idx}.pt'))
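
# ---------------------------------------------------------------------------
# Note: the loop above trains only unet 1 (the 256 px base stage). A minimal
# sketch for also training the 512 px super-resolution stage in the same loop
# (assuming the same dataloader; the Decoder resizes each batch to the target
# size of the unet being trained):
#
#     for unet_number in (1, 2):
#         loss = decoder_trainer(
#             images.cuda(),
#             text=text.cuda(),
#             unet_number=unet_number,
#             max_batch_size=4
#         )
#         decoder_trainer.update(unet_number)
#
# For real 512 px quality the source images should be at least 512 x 512,
# which would also mean raising image_size above.
# ---------------------------------------------------------------------------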