import argparse
import csv
import datetime
import glob
import importlib
import os
import sys
import time

import numpy as np
import torch
import torchvision
import lightning.pytorch as pl

from functools import partial

from omegaconf import OmegaConf
from packaging import version
from PIL import Image
from prefetch_generator import BackgroundGenerator
from torch.utils.data import DataLoader, Dataset, Subset, random_split
from ldm.models.diffusion.ddpm import LatentDiffusion

from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.trainer import Trainer
from lightning.pytorch.utilities import rank_zero_info, rank_zero_only
from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger
from lightning.pytorch.strategies import ColossalAIStrategy, DDPStrategy
LIGHTNING_PACK_NAME = "lightning.pytorch."

from ldm.data.base import Txt2ImgIterableBaseDataset
from ldm.util import instantiate_from_config

# from ldm.modules.attention import enable_flash_attentions


class DataLoaderX(DataLoader):
    # A custom data loader class that inherits from DataLoader
    def __iter__(self):
        # Overriding the __iter__ method of DataLoader to return a BackgroundGenerator
        # This enables data loading in the background to improve training performance
        return BackgroundGenerator(super().__iter__())


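# A minimal usage sketch (hypothetical dataset variable): DataLoaderX accepts the
# same constructor arguments as torch.utils.data.DataLoader, e.g.
#
#     loader = DataLoaderX(my_dataset, batch_size=4, num_workers=2)
#     for batch in loader:  # batches are prefetched in a background thread
#         ...
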
def get_parser(**parser_kwargs):
    # A function to create an ArgumentParser object and add arguments to it

    def str2bool(v):
        # A helper function to parse boolean values from command line arguments
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        elif v.lower() in ("no", "false", "f", "n", "0"):
            return False
        else:
            raise argparse.ArgumentTypeError("Boolean value expected.")
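
    # For example (illustrative): str2bool("yes") -> True, str2bool("0") -> False;
    # any other string raises argparse.ArgumentTypeError.
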
    # Create an ArgumentParser object with the specified kwargs
    parser = argparse.ArgumentParser(**parser_kwargs)

    # Add various command line arguments with their default values and descriptions
    parser.add_argument(
        "-n",
        "--name",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="postfix for logdir",
    )
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="resume from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "-b",
        "--base",
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. Loaded from left-to-right. "
        "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        default=list(),
    )
    parser.add_argument(
        "-t",
        "--train",
        type=str2bool,
        const=True,
        default=False,
        nargs="?",
        help="train",
    )
    parser.add_argument(
        "--no-test",
        type=str2bool,
        const=True,
        default=False,
        nargs="?",
        help="disable test",
    )
    parser.add_argument(
        "-p",
        "--project",
        help="name of new or path to existing project",
    )
    parser.add_argument(
        "-c",
        "--ckpt",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="load pretrained checkpoint from stable AI",
    )
    parser.add_argument(
        "-d",
        "--debug",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="enable post-mortem debugging",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=23,
        help="seed for seed_everything",
    )
    parser.add_argument(
        "-f",
        "--postfix",
        type=str,
        default="",
        help="post-postfix for default name",
    )
    parser.add_argument(
        "-l",
        "--logdir",
        type=str,
        default="logs",
        help="directory for logging dat shit",
    )
    parser.add_argument(
        "--scale_lr",
        type=str2bool,
        nargs="?",
        const=True,
        default=True,
        help="scale base-lr by ngpu * batch_size * n_accumulate",
    )

    return parser

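# An example invocation (hypothetical config path); `--key value` flags go to the
# parser/Trainer, while `nested.key=value` pairs become OmegaConf overrides:
#
#     python main.py --train True --base configs/train.yaml --logdir ./logs \
#         model.base_learning_rate=1.0e-4
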
# A function that returns the Trainer arguments of `opt` that differ from the parser defaults
def nondefault_trainer_args(opt):
    # create an argument parser
    parser = argparse.ArgumentParser()
    # add PyTorch Lightning Trainer default arguments
    parser = Trainer.add_argparse_args(parser)
    # parse the empty arguments to obtain the default values
    args = parser.parse_args([])
    # return all non-default arguments
    return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))

# A dataset wrapper class to create a PyTorch dataset from an arbitrary object
class WrappedDataset(Dataset):
    """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""

    def __init__(self, dataset):
        self.data = dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# A function to initialize worker processes
def worker_init_fn(_):
    worker_info = torch.utils.data.get_worker_info()

    dataset = worker_info.dataset
    worker_id = worker_info.id

    if isinstance(dataset, Txt2ImgIterableBaseDataset):
        # divide the dataset into equal parts for each worker
        split_size = dataset.num_records // worker_info.num_workers
        # set the sample IDs for the current worker
        # reset num_records to the true number to retain reliable length information
        dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
        # set the seed for the current worker
        current_id = np.random.choice(len(np.random.get_state()[1]), 1)
        return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
    else:
        return np.random.seed(np.random.get_state()[1][0] + worker_id)

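# A worked example of the sharding in worker_init_fn (illustrative numbers): with
# dataset.num_records = 1000 and worker_info.num_workers = 4, split_size = 250 and
# worker 1 receives valid_ids[250:500]. Integer division drops any remainder, so up
# to num_workers - 1 trailing records can go unassigned.
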
# Provides functionality for creating data loaders based on the provided dataset configurations
class DataModuleFromConfig(pl.LightningDataModule):

    def __init__(self,
                 batch_size,
                 train=None,
                 validation=None,
                 test=None,
                 predict=None,
                 wrap=False,
                 num_workers=None,
                 shuffle_test_loader=False,
                 use_worker_init_fn=False,
                 shuffle_val_dataloader=False):
        super().__init__()
        # Set data module attributes
        self.batch_size = batch_size
        self.dataset_configs = dict()
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        self.use_worker_init_fn = use_worker_init_fn
        # If a dataset is passed, add it to the dataset configs and create a corresponding dataloader method
        if train is not None:
            self.dataset_configs["train"] = train
            self.train_dataloader = self._train_dataloader
        if validation is not None:
            self.dataset_configs["validation"] = validation
            self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
        if test is not None:
            self.dataset_configs["test"] = test
            self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
        if predict is not None:
            self.dataset_configs["predict"] = predict
            self.predict_dataloader = self._predict_dataloader
        self.wrap = wrap

    def prepare_data(self):
        # Instantiate datasets
        for data_cfg in self.dataset_configs.values():
            instantiate_from_config(data_cfg)

    def setup(self, stage=None):
        # Instantiate datasets from the dataset configs
        self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)

        # If wrap is true, create a WrappedDataset for each dataset
        if self.wrap:
            for k in self.datasets:
                self.datasets[k] = WrappedDataset(self.datasets[k])

    def _train_dataloader(self):
        # Check if the train dataset is iterable
        is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
        # Set the worker initialization function if the dataset is iterable or use_worker_init_fn is True
        if is_iterable_dataset or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        # Return a DataLoaderX object for the train dataset
        return DataLoaderX(self.datasets["train"],
                           batch_size=self.batch_size,
                           num_workers=self.num_workers,
                           shuffle=False if is_iterable_dataset else True,
                           worker_init_fn=init_fn)

    def _val_dataloader(self, shuffle=False):
        # Check if the validation dataset is iterable
        if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        # Return a DataLoaderX object for the validation dataset
        return DataLoaderX(self.datasets["validation"],
                           batch_size=self.batch_size,
                           num_workers=self.num_workers,
                           worker_init_fn=init_fn,
                           shuffle=shuffle)

    def _test_dataloader(self, shuffle=False):
        # Check if the test dataset is iterable
        is_iterable_dataset = isinstance(self.datasets['test'], Txt2ImgIterableBaseDataset)
        # Set the worker initialization function if the dataset is iterable or use_worker_init_fn is True
        if is_iterable_dataset or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None

        # do not shuffle dataloader for iterable dataset
        shuffle = shuffle and (not is_iterable_dataset)

        return DataLoaderX(self.datasets["test"],
                           batch_size=self.batch_size,
                           num_workers=self.num_workers,
                           worker_init_fn=init_fn,
                           shuffle=shuffle)

    def _predict_dataloader(self, shuffle=False):
        if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoaderX(self.datasets["predict"],
                           batch_size=self.batch_size,
                           num_workers=self.num_workers,
                           worker_init_fn=init_fn)


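# A sketch of the config that instantiates DataModuleFromConfig (mirrors the schema
# documented in the __main__ block below; the dataset target here is hypothetical):
#
#     data:
#       target: main.DataModuleFromConfig
#       params:
#         batch_size: 4
#         wrap: false
#         train:
#           target: ldm.data.SomeTrainDataset  # hypothetical dataset class
#           params:
#             size: 512
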
class SetupCallback(Callback):
    # Initialize the callback with the necessary parameters
    def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
        super().__init__()
        self.resume = resume
        self.now = now
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        self.config = config
        self.lightning_config = lightning_config

    # Save a checkpoint if training is interrupted by a keyboard interrupt
    def on_keyboard_interrupt(self, trainer, pl_module):
        if trainer.global_rank == 0:
            print("Summoning checkpoint.")
            ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
            trainer.save_checkpoint(ckpt_path)

    # Create necessary directories and save configuration files before training starts
    # def on_pretrain_routine_start(self, trainer, pl_module):
    def on_fit_start(self, trainer, pl_module):
        if trainer.global_rank == 0:
            # Create logdirs and save configs
            os.makedirs(self.logdir, exist_ok=True)
            os.makedirs(self.ckptdir, exist_ok=True)
            os.makedirs(self.cfgdir, exist_ok=True)

            # Create trainstep checkpoint directory if necessary
            if "callbacks" in self.lightning_config:
                if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']:
                    os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True)
            print("Project config")
            print(OmegaConf.to_yaml(self.config))
            OmegaConf.save(self.config, os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))

            # Save project config and lightning config as YAML files
            print("Lightning config")
            print(OmegaConf.to_yaml(self.lightning_config))
            OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
                           os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))

        # On non-zero ranks: if not resuming and the log directory already exists,
        # move it aside into "child_runs"
        else:
            # ModelCheckpoint callback created log directory --- remove it
            if not self.resume and os.path.exists(self.logdir):
                dst, name = os.path.split(self.logdir)
                dst = os.path.join(dst, "child_runs", name)
                os.makedirs(os.path.split(dst)[0], exist_ok=True)
                try:
                    os.rename(self.logdir, dst)
                except FileNotFoundError:
                    pass

    # def on_fit_end(self, trainer, pl_module):
    #     if trainer.global_rank == 0:
    #         ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
    #         rank_zero_info(f"Saving final checkpoint in {ckpt_path}.")
    #         trainer.save_checkpoint(ckpt_path)


# PyTorch Lightning callback for logging images during training and validation of a deep learning model
class ImageLogger(Callback):

    def __init__(self,
                 batch_frequency, # Frequency of batches on which to log images
                 max_images,      # Maximum number of images to log
                 clamp=True,      # Whether to clamp pixel values to [-1,1]
                 increase_log_steps=True,   # Whether to increase frequency of log steps exponentially
                 rescale=True,    # Whether to rescale pixel values to [0,1]
                 disabled=False,  # Whether to disable logging
                 log_on_batch_idx=False,   # Whether to log on batch index instead of global step
                 log_first_step=False,     # Whether to log on the first step
                 log_images_kwargs=None):  # Additional keyword arguments to pass to log_images method
        super().__init__()
        self.rescale = rescale
        self.batch_freq = batch_frequency
        self.max_images = max_images
        self.logger_log_images = {
            # Dictionary of logger classes and their corresponding logging methods
            pl.loggers.CSVLogger: self._testtube,
        }
        # Create a list of exponentially increasing log steps, starting from 1 and ending at batch_frequency
        self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp
        self.disabled = disabled
        self.log_on_batch_idx = log_on_batch_idx
        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
        self.log_first_step = log_first_step

    @rank_zero_only   # Ensure that only the first process in distributed training executes this method
    def _testtube(self,
                  pl_module,    # The PyTorch Lightning module
                  images,       # A dictionary of images to log
                  batch_idx,    # The batch index
                  split         # The split (train/val) on which to log the images
                  ):
        # Method for logging images using the test-tube style (CSV) logger
        for k in images:
            grid = torchvision.utils.make_grid(images[k])
            grid = (grid + 1.0) / 2.0    # -1,1 -> 0,1; c,h,w

            tag = f"{split}/{k}"
            # Add image grid to logger's experiment
            pl_module.logger.experiment.add_image(tag, grid, global_step=pl_module.global_step)

    @rank_zero_only
    def log_local(self,
                  save_dir,      # The directory in which to save images
                  split,         # The split (train/val) on which to log the images
                  images,        # A dictionary of images to log
                  global_step,   # The global step
                  current_epoch, # The current epoch
                  batch_idx      # The batch index
                  ):
        # Method for saving image grids to the local file system
        root = os.path.join(save_dir, "images", split)
        for k in images:
            grid = torchvision.utils.make_grid(images[k], nrow=4)
            if self.rescale:
                grid = (grid + 1.0) / 2.0    # -1,1 -> 0,1; c,h,w
            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
            grid = grid.numpy()
            grid = (grid * 255).astype(np.uint8)
            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
            path = os.path.join(root, filename)
            os.makedirs(os.path.split(path)[0], exist_ok=True)
            # Save image grid as PNG file
            Image.fromarray(grid).save(path)

    def log_img(self, pl_module, batch, batch_idx, split="train"):
        # Function for logging images to both the logger and the local file system
        check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
        # check if it's time to log an image batch
        if (self.check_frequency(check_idx) and    # batch_idx % self.batch_freq == 0
                hasattr(pl_module, "log_images") and callable(pl_module.log_images) and self.max_images > 0):
            # Get logger type and check if training mode is on
            logger = type(pl_module.logger)

            is_train = pl_module.training
            if is_train:
                pl_module.eval()

            with torch.no_grad():
                # Get images from log_images method of the pl_module
                images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)

            # Clamp images if specified and convert to CPU tensors
            for k in images:
                N = min(images[k].shape[0], self.max_images)
                images[k] = images[k][:N]
                if isinstance(images[k], torch.Tensor):
                    images[k] = images[k].detach().cpu()
                    if self.clamp:
                        images[k] = torch.clamp(images[k], -1., 1.)

            # Log images locally to file system
            self.log_local(pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch,
                           batch_idx)

            # log the images using the logger
            logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
            logger_log_images(pl_module, images, pl_module.global_step, split)

            # switch back to training mode if necessary
            if is_train:
                pl_module.train()

    # The function checks if it's time to log an image batch
    def check_frequency(self, check_idx):
        if ((check_idx % self.batch_freq) == 0 or
            (check_idx in self.log_steps)) and (check_idx > 0 or self.log_first_step):
            try:
                self.log_steps.pop(0)
            except IndexError as e:
                print(e)
                pass
            return True
        return False

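    # A worked example of the schedule above: with batch_frequency = 750 and
    # increase_log_steps = True, __init__ builds log_steps = [1, 2, 4, 8, 16, 32,
    # 64, 128, 256, 512] (2**n for n in range(int(log2(750)) + 1)), so logging is
    # dense early on and later falls back to the check_idx % 750 == 0 schedule.
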
    # Log images on train batch end (currently a no-op; the logging call is commented out)
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
        #     self.log_img(pl_module, batch, batch_idx, split="train")
        pass

    # Log images on validation batch end if logging is not disabled and in validation mode
    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if not self.disabled and pl_module.global_step > 0:
            self.log_img(pl_module, batch, batch_idx, split="val")
        # log gradients during calibration if necessary
        if hasattr(pl_module, 'calibrate_grad_norm'):
            if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
                self.log_gradients(trainer, pl_module, batch_idx=batch_idx)


class CUDACallback(Callback):
    # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py

    def on_train_start(self, trainer, pl_module):
        rank_zero_info("Training is starting")

    # The following hooks time each training epoch and track peak GPU memory
    def on_train_end(self, trainer, pl_module):
        rank_zero_info("Training is ending")

    def on_train_epoch_start(self, trainer, pl_module):
        # Reset the memory use counter
        torch.cuda.reset_peak_memory_stats(trainer.strategy.root_device.index)
        torch.cuda.synchronize(trainer.strategy.root_device.index)
        self.start_time = time.time()

    def on_train_epoch_end(self, trainer, pl_module):
        torch.cuda.synchronize(trainer.strategy.root_device.index)
        max_memory = torch.cuda.max_memory_allocated(trainer.strategy.root_device.index) / 2**20
        epoch_time = time.time() - self.start_time

        try:
            max_memory = trainer.strategy.reduce(max_memory)
            epoch_time = trainer.strategy.reduce(epoch_time)

            rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
            rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
        except AttributeError:
            pass


if __name__ == "__main__":
    # custom parser to specify config files, train, test and debug mode,
    # postfix, resume.
    # `--key value` arguments are interpreted as arguments to the trainer.
    # `nested.key=value` arguments are interpreted as config parameters.
    # configs are merged from left-to-right followed by command line parameters.

    # model:
    #   base_learning_rate: float
    #   target: path to lightning module
    #   params:
    #       key: value
    # data:
    #   target: main.DataModuleFromConfig
    #   params:
    #      batch_size: int
    #      wrap: bool
    #      train:
    #          target: path to train dataset
    #          params:
    #              key: value
    #      validation:
    #          target: path to validation dataset
    #          params:
    #              key: value
    #      test:
    #          target: path to test dataset
    #          params:
    #              key: value
    # lightning: (optional, has sane defaults and can be specified on cmdline)
    #   trainer:
    #       additional arguments to trainer
    #   logger:
    #       logger to instantiate
    #   modelcheckpoint:
    #       modelcheckpoint to instantiate
    #   callbacks:
    #       callback1:
    #           target: importpath
    #           params:
    #               key: value

    # get the current time to create a new logging directory
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")

    # add cwd for convenience and to make classes in this file available when
    # running as `python main.py`
    # (in particular `main.DataModuleFromConfig`)
    sys.path.append(os.getcwd())

    parser = get_parser()
    parser = Trainer.add_argparse_args(parser)

    opt, unknown = parser.parse_known_args()
    # Verify that -n/--name and -r/--resume are not both specified
    if opt.name and opt.resume:
        raise ValueError("-n/--name and -r/--resume cannot both be specified. "
                         "If you want to resume training in a new log folder, "
                         "use -n/--name in combination with --resume_from_checkpoint")

    # If the "resume" option is specified, resume training from the given checkpoint or log directory
    ckpt = None
    if opt.resume:
        rank_zero_info("Resuming from {}".format(opt.resume))
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            paths = opt.resume.split("/")
            # idx = len(paths)-paths[::-1].index("logs")+1
            # logdir = "/".join(paths[:idx])
            logdir = "/".join(paths[:-2])
            rank_zero_info("logdir: {}".format(logdir))
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")

        # Find all ".yaml" configuration files in the log directory and add them to the list of base configurations
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
        opt.base = base_configs + opt.base
        # Get the name of the current log directory by splitting the path and taking the last element
        _tmp = logdir.split("/")
        nowname = _tmp[-1]
    else:
        if opt.name:
            name = "_" + opt.name
        elif opt.base:
            rank_zero_info("Using base config {}".format(opt.base))
            cfg_fname = os.path.split(opt.base[0])[-1]
            cfg_name = os.path.splitext(cfg_fname)[0]
            name = "_" + cfg_name
        else:
            name = ""
        nowname = now + name + opt.postfix
        logdir = os.path.join(opt.logdir, nowname)

        # Set the checkpoint path if the 'ckpt' option is specified
        if opt.ckpt:
            ckpt = opt.ckpt

    # Create the checkpoint and configuration directories within the log directory.
    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")
    # Sets the seed for the random number generator to ensure reproducibility
    seed_everything(opt.seed)

    # Initialize and save the configuration using the OmegaConf library.
    try:
        # init and save configs
        configs = [OmegaConf.load(cfg) for cfg in opt.base]
        cli = OmegaConf.from_dotlist(unknown)
        config = OmegaConf.merge(*configs, cli)
        lightning_config = config.pop("lightning", OmegaConf.create())
        # merge trainer cli with config
        trainer_config = lightning_config.get("trainer", OmegaConf.create())

        for k in nondefault_trainer_args(opt):
            trainer_config[k] = getattr(opt, k)

        # Check whether the accelerator is gpu
        if not trainer_config["accelerator"] == "gpu":
            del trainer_config["accelerator"]
            cpu = True
        else:
            cpu = False
        trainer_opt = argparse.Namespace(**trainer_config)
        lightning_config.trainer = trainer_config

        # model
        use_fp16 = trainer_config.get("precision", 32) == 16
        if use_fp16:
            config.model["params"].update({"use_fp16": True})
        else:
            config.model["params"].update({"use_fp16": False})

        if ckpt is not None:
            # If a checkpoint path is specified in ckpt, store it under the "ckpt" key
            # of the model params so the pretrained weights are loaded
            config.model["params"].update({"ckpt": ckpt})
            rank_zero_info("Using ckpt_path = {}".format(config.model["params"]["ckpt"]))

        model = LatentDiffusion(**config.model.get("params", dict()))
        # trainer and callbacks
        trainer_kwargs = dict()

        # configure the logger
        # Default logger configs for logging training metrics during training
        default_logger_cfgs = {
            "wandb": {
                    "name": nowname,
                    "save_dir": logdir,
                    "offline": opt.debug,
                    "id": nowname,
                },
            "tensorboard": {
                    "save_dir": logdir,
                    "name": "diff_tb",
                    "log_graph": True
                }
        }

        # Set up the logger; defaults to TensorBoard, or WandB when a logger is configured in lightning_config
        default_logger_cfg = default_logger_cfgs["tensorboard"]
        if "logger" in lightning_config:
            logger_cfg = lightning_config.logger
            trainer_kwargs["logger"] = WandbLogger(**logger_cfg)
        else:
            logger_cfg = default_logger_cfg
            trainer_kwargs["logger"] = TensorBoardLogger(**logger_cfg)

        # configure the strategy; default is DDP
        if "strategy" in trainer_config:
            strategy_cfg = trainer_config["strategy"]
            trainer_kwargs["strategy"] = ColossalAIStrategy(**strategy_cfg)
        else:
            strategy_cfg = {"find_unused_parameters": False}
            trainer_kwargs["strategy"] = DDPStrategy(**strategy_cfg)

        # Set up ModelCheckpoint callback to save best models
        # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
        # specify which metric is used to determine best models
        default_modelckpt_cfg = {
                "dirpath": ckptdir,
                "filename": "{epoch:06}",
                "verbose": True,
                "save_last": True,
            }
        if hasattr(model, "monitor"):
            default_modelckpt_cfg["monitor"] = model.monitor
            default_modelckpt_cfg["save_top_k"] = 3

        if "modelcheckpoint" in lightning_config:
            modelckpt_cfg = lightning_config.modelcheckpoint["params"]
        else:
            modelckpt_cfg = OmegaConf.create()
        modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
        if version.parse(pl.__version__) < version.parse('1.4.0'):
            trainer_kwargs["checkpoint_callback"] = ModelCheckpoint(**modelckpt_cfg)

        # Create an empty OmegaConf configuration object
        callbacks_cfg = OmegaConf.create()

        # Instantiate the callbacks according to the configs
        trainer_kwargs.setdefault("callbacks", [])
        setup_callback_config = {
            "resume": opt.resume,                 # resume training if applicable
            "now": now,                           # timestamp used for naming
            "logdir": logdir,                     # directory to save the log file
            "ckptdir": ckptdir,                   # directory to save the checkpoint file
            "cfgdir": cfgdir,                     # directory to save the configuration file
            "config": config,                     # configuration dictionary
            "lightning_config": lightning_config, # LightningModule configuration
        }
        trainer_kwargs["callbacks"].append(SetupCallback(**setup_callback_config))

        image_logger_config = {
            "batch_frequency": 750,               # how frequently to log images
            "max_images": 4,                      # maximum number of images to log
            "clamp": True                         # whether to clamp pixel values to [-1,1]
        }
        trainer_kwargs["callbacks"].append(ImageLogger(**image_logger_config))
        
        learning_rate_logger_config = {
            "logging_interval": "step",           # logging frequency (either 'step' or 'epoch')
            # "log_momentum": True                # whether to log momentum (currently commented out)
        }
        trainer_kwargs["callbacks"].append(LearningRateMonitor(**learning_rate_logger_config))

        metrics_over_trainsteps_checkpoint_config = {
            "dirpath": os.path.join(ckptdir, "trainstep_checkpoints"),
            "filename": "{epoch:06}-{step:09}",
            "verbose": True,
            "save_top_k": -1,
            "every_n_train_steps": 10000,
            "save_weights_only": True
        }
        trainer_kwargs["callbacks"].append(ModelCheckpoint(**metrics_over_trainsteps_checkpoint_config))
        trainer_kwargs["callbacks"].append(CUDACallback())

        # Create a Trainer object with the specified command-line arguments and keyword arguments, and set the log directory
        trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
        trainer.logdir = logdir

        # Create a data module based on the configuration file
        data = DataModuleFromConfig(**config.data)

        # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
        # calling these ourselves should not be necessary but it is.
        # lightning still takes care of proper multiprocessing though
        data.prepare_data()
        data.setup()

        # Print some information about the datasets in the data module
        for k in data.datasets:
            rank_zero_info(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")

        # Configure learning rate based on the batch size, base learning rate and number of GPUs
        # If scale_lr is true, calculate the learning rate based on additional factors
        bs, base_lr = config.data.batch_size, config.model.base_learning_rate
        if not cpu:
            ngpu = trainer_config["devices"]
        else:
            ngpu = 1
        if 'accumulate_grad_batches' in lightning_config.trainer:
            accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
        else:
            accumulate_grad_batches = 1
        rank_zero_info(f"accumulate_grad_batches = {accumulate_grad_batches}")
        lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
        if opt.scale_lr:
            model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
            rank_zero_info(
                "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)"
                .format(model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
        else:
            model.learning_rate = base_lr
            rank_zero_info("++++ NOT USING LR SCALING ++++")
            rank_zero_info(f"Setting learning rate to {model.learning_rate:.2e}")

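        # A worked example of the scaling rule above (illustrative numbers): with
        # accumulate_grad_batches = 2, ngpu = 4, bs = 8 and base_lr = 1.0e-4, the
        # scaled rate is 2 * 4 * 8 * 1.0e-4 = 6.4e-3.
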
        # Allow checkpointing via USR1
        def melk(*args, **kwargs):
            # run all checkpoint hooks
            if trainer.global_rank == 0:
                print("Summoning checkpoint.")
                ckpt_path = os.path.join(ckptdir, "last.ckpt")
                trainer.save_checkpoint(ckpt_path)

        def divein(*args, **kwargs):
            if trainer.global_rank == 0:
                import pudb
                pudb.set_trace()

        import signal
        # Assign melk to SIGUSR1 signal and divein to SIGUSR2 signal
        signal.signal(signal.SIGUSR1, melk)
        signal.signal(signal.SIGUSR2, divein)

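        # For example, from another shell (hypothetical PID): `kill -USR1 <pid>`
        # saves last.ckpt via melk(), and `kill -USR2 <pid>` drops into pudb.
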
        # Run the training and validation
        if opt.train:
            try:
                trainer.fit(model, data)
            except Exception:
                melk()
                raise
        # Print the maximum GPU memory allocated during training
        print(f"GPU memory usage: {torch.cuda.max_memory_allocated() / 1024**2:.0f} MB")
        # if not opt.no_test and not trainer.interrupted:
        #     trainer.test(model, data)
    except Exception:
        # If there's an exception, debug it if opt.debug is true and the trainer's global rank is 0
        if opt.debug and trainer.global_rank == 0:
            try:
                import pudb as debugger
            except ImportError:
                import pdb as debugger
            debugger.post_mortem()
        raise
    finally:
        # Move the log directory to debug_runs if opt.debug is true, not resuming, and the trainer's global rank is 0
        if opt.debug and not opt.resume and trainer.global_rank == 0:
            dst, name = os.path.split(logdir)
            dst = os.path.join(dst, "debug_runs", name)
            os.makedirs(os.path.split(dst)[0], exist_ok=True)
            os.rename(logdir, dst)
        if trainer.global_rank == 0:
            print(trainer.profiler.summary())