main.py 39.2 KB
Newer Older
Fazzie's avatar
Fazzie committed
1
2
3
4
5
6
7
import argparse
import csv
import datetime
import glob
import importlib
import os
import sys
8
import time
Fazzie's avatar
Fazzie committed
9
10

import numpy as np
11
12
13
import torch
import torchvision

Fazzie's avatar
Fazzie committed
14
15
16
17
18
try:
    import lightning.pytorch as pl
except:
    import pytorch_lightning as pl

19
from functools import partial
Fazzie's avatar
Fazzie committed
20
21
22

from omegaconf import OmegaConf
from packaging import version
23
24
from PIL import Image
from prefetch_generator import BackgroundGenerator
Fazzie's avatar
Fazzie committed
25
from torch.utils.data import DataLoader, Dataset, Subset, random_split
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from ldm.models.diffusion.ddpm import LatentDiffusion
#try:
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.trainer import Trainer
from lightning.pytorch.utilities import rank_zero_info, rank_zero_only
from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger
from lightning.pytorch.strategies import ColossalAIStrategy,DDPStrategy
LIGHTNING_PACK_NAME = "lightning.pytorch."
# #except:
#     from pytorch_lightning import seed_everything
#     from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
#     from pytorch_lightning.trainer import Trainer
#     from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
#     LIGHTNING_PACK_NAME = "pytorch_lightning."
41
42
43

from ldm.data.base import Txt2ImgIterableBaseDataset
from ldm.util import instantiate_from_config
Fazzie's avatar
Fazzie committed
44
45
46

# from ldm.modules.attention import enable_flash_attentions

47
48

class DataLoaderX(DataLoader):
    """DataLoader variant that prefetches batches in a background thread.

    Overlapping host-side data loading with training computation improves
    throughput; iteration semantics are otherwise identical to DataLoader.
    """

    def __iter__(self):
        # Hand the parent iterator to a BackgroundGenerator, which pulls
        # batches ahead of time on a separate thread.
        base_iterator = super().__iter__()
        return BackgroundGenerator(base_iterator)


def get_parser(**parser_kwargs):
    """Build the command-line argument parser for training/testing runs.

    Any ``**parser_kwargs`` are forwarded to ``argparse.ArgumentParser``.

    Returns:
        argparse.ArgumentParser: parser with all run-control options attached.
    """

    def str2bool(v):
        # Accept real booleans and common string spellings so flags can be
        # used either bare (`--train`, via const=True) or valued
        # (`--train false`). Anything else is a usage error.
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        elif v.lower() in ("no", "false", "f", "n", "0"):
            return False
        else:
            raise argparse.ArgumentTypeError("Boolean value expected.")

    # Create an ArgumentParser object with the specified kwargs.
    parser = argparse.ArgumentParser(**parser_kwargs)

    # Add the various command-line arguments with their defaults and help text.
    parser.add_argument(
        "-n",
        "--name",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="postfix for logdir",
    )
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="resume from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "-b",
        "--base",
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. Loaded from left-to-right. "
        "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        default=list(),
    )
    parser.add_argument(
        "-t",
        "--train",
        type=str2bool,
        const=True,
        default=False,
        nargs="?",
        help="train",
    )
    parser.add_argument(
        "--no-test",
        type=str2bool,
        const=True,
        default=False,
        nargs="?",
        help="disable test",
    )
    parser.add_argument(
        "-p",
        "--project",
        help="name of new or path to existing project",
    )
    parser.add_argument(
        "-c",
        "--ckpt",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="load pretrained checkpoint from stable AI",
    )
    parser.add_argument(
        "-d",
        "--debug",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="enable post-mortem debugging",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=23,
        help="seed for seed_everything",
    )
    parser.add_argument(
        "-f",
        "--postfix",
        type=str,
        default="",
        help="post-postfix for default name",
    )
    parser.add_argument(
        "-l",
        "--logdir",
        type=str,
        default="logs",
        # Fixed unprofessional/garbled help text ("logging dat shit").
        help="directory for logging data",
    )
    parser.add_argument(
        "--scale_lr",
        type=str2bool,
        nargs="?",
        const=True,
        default=True,
        help="scale base-lr by ngpu * batch_size * n_accumulate",
    )

    return parser

NatalieC323's avatar
NatalieC323 committed
172
# Report which trainer options in `opt` differ from Lightning's defaults.
def nondefault_trainer_args(opt):
    """Return sorted names of trainer args whose values in `opt` are non-default."""
    # Build a parser carrying only Lightning's default trainer arguments.
    default_parser = Trainer.add_argparse_args(argparse.ArgumentParser())
    # Parsing an empty argv yields the default value for every option.
    defaults = default_parser.parse_args([])
    # Keep the names where `opt` disagrees with the defaults.
    changed = [name for name in vars(defaults) if getattr(opt, name) != getattr(defaults, name)]
    return sorted(changed)

NatalieC323's avatar
NatalieC323 committed
183
# Adapter that turns any indexable object into a torch Dataset.
class WrappedDataset(Dataset):
    """Expose an arbitrary object implementing __len__/__getitem__ as a pytorch Dataset."""

    def __init__(self, dataset):
        # Keep a reference to the underlying indexable object.
        self.data = dataset

    def __len__(self):
        # Delegate length to the wrapped object.
        return len(self.data)

    def __getitem__(self, idx):
        # Delegate item access to the wrapped object.
        return self.data[idx]

NatalieC323's avatar
NatalieC323 committed
196
# A function to initialize DataLoader worker processes: shards iterable
# datasets across workers and reseeds numpy so workers do not produce
# identical random augmentations.
def worker_init_fn(_):
    worker_info = torch.utils.data.get_worker_info()

    dataset = worker_info.dataset
    worker_id = worker_info.id

    if isinstance(dataset, Txt2ImgIterableBaseDataset):
        # Divide the dataset into equal parts, one contiguous slice per worker.
        # NOTE(review): any records beyond num_workers * split_size are dropped
        # by this sharding — confirm that is intended.
        split_size = dataset.num_records // worker_info.num_workers
        # Set the sample IDs for the current worker.
        dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
        # Derive a per-worker numpy seed from a randomly chosen word of the
        # current global RNG state, offset by the worker id.
        # NOTE(review): np.random.choice consumes the global RNG here, so the
        # selected state word is itself random — presumably intentional for
        # decorrelating workers; verify against upstream latent-diffusion code.
        current_id = np.random.choice(len(np.random.get_state()[1]), 1)
        return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
    else:
        # Non-iterable datasets: offset the first RNG state word by worker id
        # so each worker gets a distinct numpy seed.
        return np.random.seed(np.random.get_state()[1][0] + worker_id)

NatalieC323's avatar
NatalieC323 committed
215
# Lightning data module that builds dataloaders from dataset configurations.
class DataModuleFromConfig(pl.LightningDataModule):
    """Creates train/val/test/predict dataloaders from instantiation configs.

    Each of ``train``/``validation``/``test``/``predict`` is an
    ``instantiate_from_config`` dict; only the splits actually provided get a
    corresponding dataloader method bound on the instance.
    """

    def __init__(self,
                 batch_size,
                 train=None,
                 validation=None,
                 test=None,
                 predict=None,
                 wrap=False,
                 num_workers=None,
                 shuffle_test_loader=False,
                 use_worker_init_fn=False,
                 shuffle_val_dataloader=False):
        super().__init__()
        # Core loader settings; default worker count scales with batch size.
        self.batch_size = batch_size
        self.dataset_configs = dict()
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        self.use_worker_init_fn = use_worker_init_fn
        # Register only the splits that were configured, binding each one to
        # its dataloader factory so Lightning discovers it.
        if train is not None:
            self.dataset_configs["train"] = train
            self.train_dataloader = self._train_dataloader
        if validation is not None:
            self.dataset_configs["validation"] = validation
            self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
        if test is not None:
            self.dataset_configs["test"] = test
            self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
        if predict is not None:
            self.dataset_configs["predict"] = predict
            self.predict_dataloader = self._predict_dataloader
        self.wrap = wrap

    def prepare_data(self):
        # Instantiate each dataset once for any one-time side effects
        # (e.g. downloads); the instances are discarded — setup() builds
        # the datasets actually used.
        for data_cfg in self.dataset_configs.values():
            instantiate_from_config(data_cfg)

    def setup(self, stage=None):
        # Instantiate the datasets from the stored configs.
        self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
        # If wrap is true, adapt each dataset through WrappedDataset.
        if self.wrap:
            for k in self.datasets:
                self.datasets[k] = WrappedDataset(self.datasets[k])

    def _train_dataloader(self):
        # Iterable datasets need the sharding/seeding worker init function
        # and must not be shuffled (they control their own ordering).
        is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
        if is_iterable_dataset or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoaderX(self.datasets["train"],
                           batch_size=self.batch_size,
                           num_workers=self.num_workers,
                           shuffle=False if is_iterable_dataset else True,
                           worker_init_fn=init_fn)

    def _val_dataloader(self, shuffle=False):
        # Use the worker init function if the dataset is iterable or requested.
        if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoaderX(self.datasets["validation"],
                           batch_size=self.batch_size,
                           num_workers=self.num_workers,
                           worker_init_fn=init_fn,
                           shuffle=shuffle)

    def _test_dataloader(self, shuffle=False):
        # BUGFIX: this previously inspected self.datasets['train'] — the test
        # loader must check the test split itself for iterability.
        is_iterable_dataset = isinstance(self.datasets['test'], Txt2ImgIterableBaseDataset)
        if is_iterable_dataset or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None

        # do not shuffle dataloader for iterable dataset
        shuffle = shuffle and (not is_iterable_dataset)

        return DataLoaderX(self.datasets["test"],
                           batch_size=self.batch_size,
                           num_workers=self.num_workers,
                           worker_init_fn=init_fn,
                           shuffle=shuffle)

    def _predict_dataloader(self, shuffle=False):
        # Use the worker init function if the dataset is iterable or requested.
        if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoaderX(self.datasets["predict"],
                           batch_size=self.batch_size,
                           num_workers=self.num_workers,
                           worker_init_fn=init_fn)
319
320
321


class SetupCallback(Callback):
    """Prepares run directories and persists the resolved configs at fit start.

    Also saves a final "last.ckpt" if training is interrupted from the keyboard.
    """

    def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
        super().__init__()
        self.resume = resume
        self.now = now
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        self.config = config
        self.lightning_config = lightning_config

    def on_keyboard_interrupt(self, trainer, pl_module):
        # Persist a last checkpoint when training is interrupted (rank 0 only).
        if trainer.global_rank != 0:
            return
        print("Summoning checkpoint.")
        ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
        trainer.save_checkpoint(ckpt_path)

    # def on_pretrain_routine_start(self, trainer, pl_module):
    def on_fit_start(self, trainer, pl_module):
        if trainer.global_rank == 0:
            # Create all run directories up front.
            for directory in (self.logdir, self.ckptdir, self.cfgdir):
                os.makedirs(directory, exist_ok=True)

            # Create the trainstep checkpoint directory if that callback is configured.
            needs_trainstep_dir = ("callbacks" in self.lightning_config
                                   and 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks'])
            if needs_trainstep_dir:
                os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True)

            # Dump both configs to stdout and save them as timestamped YAML files.
            print("Project config")
            print(OmegaConf.to_yaml(self.config))
            OmegaConf.save(self.config, os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))

            print("Lightning config")
            print(OmegaConf.to_yaml(self.lightning_config))
            OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
                           os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))
        else:
            # On non-zero ranks the ModelCheckpoint callback may already have
            # created the log directory — move it aside unless we are resuming.
            if not self.resume and os.path.exists(self.logdir):
                dst, name = os.path.split(self.logdir)
                dst = os.path.join(dst, "child_runs", name)
                os.makedirs(os.path.split(dst)[0], exist_ok=True)
                try:
                    os.rename(self.logdir, dst)
                except FileNotFoundError:
                    pass

382

NatalieC323's avatar
NatalieC323 committed
383
# PyTorch Lightning callback for logging images during training and validation.
class ImageLogger(Callback):
    # Saves image grids produced by the module's `log_images` method both to
    # the local file system and (for known logger types) to the logger itself.

    def __init__(self,
                 batch_frequency, # Frequency (in steps) at which to log images
                 max_images,      # Maximum number of images to log per key
                 clamp=True,      # Whether to clamp pixel values to [-1, 1]
                 increase_log_steps=True,   # Also log at exponentially spaced early steps
                 rescale=True,    # Whether to rescale pixel values from [-1,1] to [0,1]
                 disabled=False,  # Whether to disable logging entirely
                 log_on_batch_idx=False,   # Use batch index instead of global step for frequency checks
                 log_first_step=False,     # Whether to log at step/index 0
                 log_images_kwargs=None):  # Extra kwargs forwarded to log_images
        super().__init__()
        self.rescale = rescale
        self.batch_freq = batch_frequency
        self.max_images = max_images
        self.logger_log_images = {
            # Maps logger classes to the method used to log with them.
            pl.loggers.CSVLogger: self._testtube,   
        }
        # Exponentially increasing log steps from 1 up to batch_frequency;
        # check_frequency pops these as they are consumed.
        self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp
        self.disabled = disabled
        self.log_on_batch_idx = log_on_batch_idx
        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
        self.log_first_step = log_first_step

    @rank_zero_only   # Only the rank-0 process logs to the experiment
    def _testtube(self,
                  pl_module,    # The PyTorch Lightning module
                  images,       # Dict mapping names to image tensors
                  batch_idx,    # The batch index (unused here)
                  split         # The split (train/val) being logged
                  ):
         # Log each image dict entry as a grid to the logger's experiment.
        for k in images:
            grid = torchvision.utils.make_grid(images[k])
            grid = (grid + 1.0) / 2.0    # -1,1 -> 0,1; c,h,w

            tag = f"{split}/{k}"
            # Add image grid to logger's experiment at the current global step.
            pl_module.logger.experiment.add_image(tag, grid, global_step=pl_module.global_step)

    @rank_zero_only
    def log_local(self,
                  save_dir,      # Root save directory (logger.save_dir)
                  split,         # The split (train/val) being logged
                  images,        # Dict mapping names to image tensors
                  global_step,   # The global step (used in the filename)
                  current_epoch, # The current epoch (used in the filename)
                  batch_idx      # The batch index (used in the filename)
                  ):
    # Save each image grid as a PNG under <save_dir>/images/<split>/.
        root = os.path.join(save_dir, "images", split)
        for k in images:
            grid = torchvision.utils.make_grid(images[k], nrow=4)
            if self.rescale:
                grid = (grid + 1.0) / 2.0    # -1,1 -> 0,1; c,h,w
            # CHW -> HWC, drop trailing singleton dim, convert to uint8 pixels.
            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
            grid = grid.numpy()
            grid = (grid * 255).astype(np.uint8)
            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
            path = os.path.join(root, filename)
            os.makedirs(os.path.split(path)[0], exist_ok=True)
            # Save image grid as a PNG file.
            Image.fromarray(grid).save(path)

    def log_img(self, pl_module, batch, batch_idx, split="train"):
    # Log images for this batch to both the logger and the local file system.
        check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
        # Only proceed when it is time to log and the module can produce images.
        if (self.check_frequency(check_idx) and    # batch_idx % self.batch_freq == 0
                hasattr(pl_module, "log_images") and callable(pl_module.log_images) and self.max_images > 0):
            logger = type(pl_module.logger)

            # Temporarily switch to eval mode while generating images.
            is_train = pl_module.training
            if is_train:
                pl_module.eval()

            with torch.no_grad():
                # Ask the module for its images for this batch.
                images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)

            # Truncate to max_images, move tensors to CPU and optionally clamp.
            for k in images:
                N = min(images[k].shape[0], self.max_images)
                images[k] = images[k][:N]
                if isinstance(images[k], torch.Tensor):
                    images[k] = images[k].detach().cpu()
                    if self.clamp:
                        images[k] = torch.clamp(images[k], -1., 1.)

            # Save the images to the local file system.
            self.log_local(pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch,
                           batch_idx)

            # Dispatch to the logger-specific method (no-op for unknown loggers).
            logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
            logger_log_images(pl_module, images, pl_module.global_step, split)

            # Restore training mode if it was active.
            if is_train:
                pl_module.train()

    # Decide whether it is time to log at this step; consumes one entry of
    # log_steps on every hit (stateful).
    def check_frequency(self, check_idx):
        if ((check_idx % self.batch_freq) == 0 or
            (check_idx in self.log_steps)) and (check_idx > 0 or self.log_first_step):
            try:
                self.log_steps.pop(0)
            except IndexError as e:
                print(e)
                pass
            return True
        return False

    # Train-batch logging is currently disabled (body intentionally commented out).
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
        #     self.log_img(pl_module, batch, batch_idx, split="train")
        pass

    # Log images at validation batch end when logging is enabled.
    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if not self.disabled and pl_module.global_step > 0:
            self.log_img(pl_module, batch, batch_idx, split="val")
        # Log gradients during calibration if the module requests it.
        # NOTE(review): log_gradients is not defined on this class — this path
        # would raise AttributeError if calibrate_grad_norm is ever set; confirm.
        if hasattr(pl_module, 'calibrate_grad_norm'):
            if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
                self.log_gradients(trainer, pl_module, batch_idx=batch_idx)


class CUDACallback(Callback):
    # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
    """Report per-epoch wall time and peak CUDA memory on the root device."""

    def on_train_start(self, trainer, pl_module):
        rank_zero_info("Training is starting")

    def on_train_end(self, trainer, pl_module):
        rank_zero_info("Training is ending")

    def on_train_epoch_start(self, trainer, pl_module):
        # Clear the peak-memory counter and sync so the epoch measurement
        # starts from a clean slate.
        device_index = trainer.strategy.root_device.index
        torch.cuda.reset_peak_memory_stats(device_index)
        torch.cuda.synchronize(device_index)
        self.start_time = time.time()

    def on_train_epoch_end(self, trainer, pl_module):
        device_index = trainer.strategy.root_device.index
        torch.cuda.synchronize(device_index)
        peak_memory_mib = torch.cuda.max_memory_allocated(device_index) / 2**20
        epoch_seconds = time.time() - self.start_time

        try:
            # Average the measurements across processes before reporting.
            peak_memory_mib = trainer.strategy.reduce(peak_memory_mib)
            epoch_seconds = trainer.strategy.reduce(epoch_seconds)

            rank_zero_info(f"Average Epoch time: {epoch_seconds:.2f} seconds")
            rank_zero_info(f"Average Peak memory {peak_memory_mib:.2f}MiB")
        except AttributeError:
            # Strategy without `reduce` support: skip the report silently.
            pass


if __name__ == "__main__":
    # custom parser to specify config files, train, test and debug mode,
    # postfix, resume.
    # `--key value` arguments are interpreted as arguments to the trainer.
    # `nested.key=value` arguments are interpreted as config parameters.
    # configs are merged from left-to-right followed by command line parameters.

    # model:
    #   base_learning_rate: float
    #   target: path to lightning module
    #   params:
    #       key: value
    # data:
    #   target: main.DataModuleFromConfig
    #   params:
    #      batch_size: int
    #      wrap: bool
    #      train:
    #          target: path to train dataset
    #          params:
    #              key: value
    #      validation:
    #          target: path to validation dataset
    #          params:
    #              key: value
    #      test:
    #          target: path to test dataset
    #          params:
    #              key: value
    # lightning: (optional, has same defaults and can be specified on cmdline)
    #   trainer:
    #       additional arguments to trainer
    #   logger:
    #       logger to instantiate
    #   modelcheckpoint:
    #       modelcheckpoint to instantiate
    #   callbacks:
    #       callback1:
    #           target: importpath
    #           params:
    #               key: value

    # get the current time to create a new logging directory
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")

    # add cwd for convenience and to make classes in this file available when
    # running as `python main.py`
    # (in particular `main.DataModuleFromConfig`)
    sys.path.append(os.getcwd())

    parser = get_parser()
    parser = Trainer.add_argparse_args(parser)

    opt, unknown = parser.parse_known_args()

    # Verify that name and resume are not both specified
    if opt.name and opt.resume:
        raise ValueError("-n/--name and -r/--resume cannot be specified both."
                         "If you want to resume training in a new log folder, "
                         "use -n/--name in combination with --resume_from_checkpoint")

    # Check if the "resume" option is specified, resume training from the checkpoint if it is true
    ckpt = None
    if opt.resume:
        rank_zero_info("Resuming from {}".format(opt.resume))
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            paths = opt.resume.split("/")
            # idx = len(paths)-paths[::-1].index("logs")+1
            # logdir = "/".join(paths[:idx])
            logdir = "/".join(paths[:-2])
            rank_zero_info("logdir: {}".format(logdir))
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")

        # Finds all ".yaml" configuration files in the log directory and adds them to the list of base configurations
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
        opt.base = base_configs + opt.base
        # Gets the name of the current log directory by splitting the path and taking the last element.
        _tmp = logdir.split("/")
        nowname = _tmp[-1]
    else:
        if opt.name:
            name = "_" + opt.name
        elif opt.base:
            rank_zero_info("Using base config {}".format(opt.base))
            cfg_fname = os.path.split(opt.base[0])[-1]
            cfg_name = os.path.splitext(cfg_fname)[0]
            name = "_" + cfg_name
        else:
            name = ""
        nowname = now + name + opt.postfix
        logdir = os.path.join(opt.logdir, nowname)

        # Sets the checkpoint path if the 'ckpt' option is specified
        if opt.ckpt:
            ckpt = opt.ckpt

    # Create the checkpoint and configuration directories within the log directory.
    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")
    # Sets the seed for the random number generator to ensure reproducibility
    seed_everything(opt.seed)

    # Initialize and save configuration using the OmegaConf library.
    try:
        # init and save configs
        configs = [OmegaConf.load(cfg) for cfg in opt.base]
        cli = OmegaConf.from_dotlist(unknown)
        config = OmegaConf.merge(*configs, cli)
        lightning_config = config.pop("lightning", OmegaConf.create())
        # merge trainer cli with config
        trainer_config = lightning_config.get("trainer", OmegaConf.create())

        for k in nondefault_trainer_args(opt):
            trainer_config[k] = getattr(opt, k)

        # Check whether the accelerator is gpu
        if not trainer_config["accelerator"] == "gpu":
            del trainer_config["accelerator"]
            cpu = True
        else:
            cpu = False
        trainer_opt = argparse.Namespace(**trainer_config)
        lightning_config.trainer = trainer_config

        # model
        use_fp16 = trainer_config.get("precision", 32) == 16
        if use_fp16:
            config.model["params"].update({"use_fp16": True})
        else:
            config.model["params"].update({"use_fp16": False})

        if ckpt is not None:
            # If a checkpoint path was resolved above, hand it to the model params
            config.model["params"].update({"ckpt": ckpt})
            rank_zero_info("Using ckpt_path = {}".format(config.model["params"]["ckpt"]))

        model = LatentDiffusion(**config.model.get("params", dict()))
        # trainer and callbacks
        trainer_kwargs = dict()

        # config the logger
        # Default logger configs to log training metrics during the training process.
        # These loggers are specified as targets in the dictionary, along with the configuration settings specific to each logger.
        default_logger_cfgs = {
            "wandb": {
                #"target": LIGHTNING_PACK_NAME + "loggers.WandbLogger",
                "params": {
                    "name": nowname,
                    "save_dir": logdir,
                    "offline": opt.debug,
                    "id": nowname,
                }
            },
            "tensorboard": {
                #"target": LIGHTNING_PACK_NAME + "loggers.TensorBoardLogger",
                "params": {
                    "save_dir": logdir,
                    "name": "diff_tb",
                    "log_graph": True
                }
            }
        }

        # Set up the logger: a user-supplied `lightning.logger` config selects W&B,
        # otherwise TensorBoard is used.
        # NOTE(review): the W&B branch merges the *tensorboard* defaults into the
        # user config before building WandbLogger — confirm this is intended.
        default_logger_cfg = default_logger_cfgs["tensorboard"]
        if "logger" in lightning_config:
            logger_cfg = lightning_config.logger
            logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
            trainer_kwargs["logger"] = WandbLogger(**logger_cfg.get("params", dict()))
        else:
            logger_cfg = default_logger_cfg
            logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
            trainer_kwargs["logger"] = TensorBoardLogger(**logger_cfg.get("params", dict()))

        # config the strategy, default is ddp
        if "strategy" in trainer_config:
            strategy_cfg = trainer_config["strategy"]
            trainer_kwargs["strategy"] = ColossalAIStrategy(**strategy_cfg.get("params", dict()))
        else:
            strategy_cfg = {
                #"target": LIGHTNING_PACK_NAME + "strategies.DDPStrategy",
                "params": {
                    "find_unused_parameters": False
                }
            }
            trainer_kwargs["strategy"] = DDPStrategy(**strategy_cfg.get("params", dict()))

        # Set up ModelCheckpoint callback to save best models
        # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
        # specify which metric is used to determine best models
        default_modelckpt_cfg = {
            #"target": LIGHTNING_PACK_NAME + "callbacks.ModelCheckpoint",
            "params": {
                "dirpath": ckptdir,
                "filename": "{epoch:06}",
                "verbose": True,
                "save_last": True,
            }
        }
        if hasattr(model, "monitor"):
            default_modelckpt_cfg["params"]["monitor"] = model.monitor
            default_modelckpt_cfg["params"]["save_top_k"] = 3

        if "modelcheckpoint" in lightning_config:
            modelckpt_cfg = lightning_config.modelcheckpoint
        else:
            modelckpt_cfg = OmegaConf.create()
        modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
        # Older lightning versions take the checkpoint callback as a trainer kwarg.
        if version.parse(pl.__version__) < version.parse('1.4.0'):
            trainer_kwargs["checkpoint_callback"] = ModelCheckpoint(**modelckpt_cfg.get("params", dict()))

        # Set up various callbacks, including logging, learning rate monitoring, and CUDA management
        # add callback which sets up log directory
        default_callbacks_cfg = {
            "setup_callback": {                           # callback to set up the training
                #"target": "main.SetupCallback",
                "params": {
                    "resume": opt.resume,                 # resume training if applicable
                    "now": now,
                    "logdir": logdir,                     # directory to save the log file
                    "ckptdir": ckptdir,                   # directory to save the checkpoint file
                    "cfgdir": cfgdir,                     # directory to save the configuration file
                    "config": config,                     # configuration dictionary
                    "lightning_config": lightning_config, # LightningModule configuration
                }
            },
            "image_logger": {                             # callback to log image data
                #"target": "main.ImageLogger",
                "params": {
                    "batch_frequency": 750,               # how frequently to log images
                    "max_images": 4,                      # maximum number of images to log
                    "clamp": True                         # whether to clamp pixel values to [0,1]
                }
            },
            "learning_rate_logger": {                     # callback to log learning rate
                #"target": "main.LearningRateMonitor",
                "params": {
                    "logging_interval": "step",           # logging frequency (either 'step' or 'epoch')
                    # "log_momentum": True                # whether to log momentum (currently commented out)
                }
            },
            "cuda_callback": {                            # callback to handle CUDA-related operations
                #"target": "main.CUDACallback"
            },
        }

        # If the LightningModule configuration has specified callbacks, use those
        # Otherwise, create an empty OmegaConf configuration object
        if "callbacks" in lightning_config:
            callbacks_cfg = lightning_config.callbacks
        else:
            callbacks_cfg = OmegaConf.create()

        # If the 'metrics_over_trainsteps_checkpoint' callback is specified in the
        # LightningModule configuration, update the default callbacks configuration
        if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg:
            print(
                'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
            default_metrics_over_trainsteps_ckpt_dict = {
                'metrics_over_trainsteps_checkpoint': {
                    #"target": LIGHTNING_PACK_NAME + 'callbacks.ModelCheckpoint',
                    'params': {
                        "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
                        "filename": "{epoch:06}-{step:09}",
                        "verbose": True,
                        'save_top_k': -1,
                        'every_n_train_steps': 10000,
                        'save_weights_only': True
                    }
                }
            }
            default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)

        # Merge the default callbacks configuration with the specified callbacks configuration, and instantiate the callbacks
        callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)

        # Instantiate items according to the configs
        trainer_kwargs.setdefault("callbacks", [])

        if "setup_callback" in callbacks_cfg:
            setup_callback_config = callbacks_cfg["setup_callback"]
            trainer_kwargs["callbacks"].append(SetupCallback(**setup_callback_config.get("params", dict())))

        if "image_logger" in callbacks_cfg:
            image_logger_config = callbacks_cfg["image_logger"]
            trainer_kwargs["callbacks"].append(ImageLogger(**image_logger_config.get("params", dict())))

        if "learning_rate_logger" in callbacks_cfg:
            learning_rate_logger_config = callbacks_cfg["learning_rate_logger"]
            trainer_kwargs["callbacks"].append(LearningRateMonitor(**learning_rate_logger_config.get("params", dict())))

        if "cuda_callback" in callbacks_cfg:
            cuda_callback_config = callbacks_cfg["cuda_callback"]
            trainer_kwargs["callbacks"].append(CUDACallback(**cuda_callback_config.get("params", dict())))

        if "metrics_over_trainsteps_checkpoint" in callbacks_cfg:
            metrics_over_config = callbacks_cfg['metrics_over_trainsteps_checkpoint']
            trainer_kwargs["callbacks"].append(ModelCheckpoint(**metrics_over_config.get("params", dict())))
        #trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
        trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
        trainer.logdir = logdir

        # Create a data module based on the configuration file
        data = DataModuleFromConfig(**config.data.get("params", dict()))
        # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
        # calling these ourselves should not be necessary but it is.
        # lightning still takes care of proper multiprocessing though
        data.prepare_data()
        data.setup()

        # Print some information about the datasets in the data module
        for k in data.datasets:
            rank_zero_info(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")

        # Configure learning rate based on the batch size, base learning rate and number of GPUs
        # If scale_lr is true, calculate the learning rate based on additional factors
        bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
        if not cpu:
            ngpu = trainer_config["devices"]
        else:
            ngpu = 1
        if 'accumulate_grad_batches' in lightning_config.trainer:
            accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
        else:
            accumulate_grad_batches = 1
        rank_zero_info(f"accumulate_grad_batches = {accumulate_grad_batches}")
        lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
        if opt.scale_lr:
            model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
            rank_zero_info(
                "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)"
                .format(model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
        else:
            model.learning_rate = base_lr
            rank_zero_info("++++ NOT USING LR SCALING ++++")
            rank_zero_info(f"Setting learning rate to {model.learning_rate:.2e}")

        # Allow checkpointing via USR1
        def melk(*args, **kwargs):
            # run all checkpoint hooks
            if trainer.global_rank == 0:
                print("Summoning checkpoint.")
                ckpt_path = os.path.join(ckptdir, "last.ckpt")
                trainer.save_checkpoint(ckpt_path)

        def divein(*args, **kwargs):
            # drop into the pudb debugger on demand
            if trainer.global_rank == 0:
                import pudb
                pudb.set_trace()

        import signal
        # Assign melk to SIGUSR1 signal and divein to SIGUSR2 signal
        signal.signal(signal.SIGUSR1, melk)
        signal.signal(signal.SIGUSR2, divein)

        # Run the training and validation
        if opt.train:
            try:
                trainer.fit(model, data)
            except Exception:
                # Save a last checkpoint before propagating the failure.
                melk()
                raise
        # Print the maximum GPU memory allocated during training
        print(f"GPU memory usage: {torch.cuda.max_memory_allocated() / 1024**2:.0f} MB")
        # if not opt.no_test and not trainer.interrupted:
        #     trainer.test(model, data)
    except Exception:
        # If there's an exception, debug it if opt.debug is true and the trainer's global rank is 0
        if opt.debug and trainer.global_rank == 0:
            try:
                import pudb as debugger
            except ImportError:
                import pdb as debugger
            debugger.post_mortem()
        raise
    finally:
        # Move the log directory to debug_runs if opt.debug is true and the trainer's global rank is 0
        if opt.debug and not opt.resume and trainer.global_rank == 0:
            dst, name = os.path.split(logdir)
            dst = os.path.join(dst, "debug_runs", name)
            os.makedirs(os.path.split(dst)[0], exist_ok=True)
            os.rename(logdir, dst)
        if trainer.global_rank == 0:
            print(trainer.profiler.summary())