main.py 34.4 KB
Newer Older
Fazzie's avatar
Fazzie committed
1
2
3
4
5
import argparse
import datetime
import glob
import os
import sys
6
import time
7
from functools import partial
Fazzie's avatar
Fazzie committed
8

9
import lightning.pytorch as pl
Fazzie's avatar
Fazzie committed
10
import numpy as np
11
12
import torch
import torchvision
natalie_cao's avatar
natalie_cao committed
13
14
15
from ldm.models.diffusion.ddpm import LatentDiffusion
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
16
17
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from lightning.pytorch.strategies import ColossalAIStrategy, DDPStrategy
natalie_cao's avatar
natalie_cao committed
18
19
from lightning.pytorch.trainer import Trainer
from lightning.pytorch.utilities import rank_zero_info, rank_zero_only
20
21
22
23
24
25
from omegaconf import OmegaConf
from packaging import version
from PIL import Image
from prefetch_generator import BackgroundGenerator
from torch.utils.data import DataLoader, Dataset

natalie_cao's avatar
natalie_cao committed
26
LIGHTNING_PACK_NAME = "lightning.pytorch."
27
28
29

from ldm.data.base import Txt2ImgIterableBaseDataset
from ldm.util import instantiate_from_config
Fazzie's avatar
Fazzie committed
30
31
32

# from ldm.modules.attention import enable_flash_attentions

33
34

class DataLoaderX(DataLoader):
    """DataLoader whose iterator is wrapped in a ``BackgroundGenerator``."""

    def __iter__(self):
        """Return the parent ``DataLoader`` iterator wrapped in a ``BackgroundGenerator``.

        The wrapper comes from ``prefetch_generator`` and is used here so that
        batches are fetched in the background while the training loop consumes
        earlier ones.
        """
        return BackgroundGenerator(super().__iter__())


def get_parser(**parser_kwargs):
    """Build the argument parser for this script's own command-line options.

    Args:
        **parser_kwargs: Forwarded unchanged to ``argparse.ArgumentParser``.

    Returns:
        argparse.ArgumentParser: Parser with the options ``-n/--name``,
        ``-r/--resume``, ``-b/--base``, ``-t/--train``, ``--no-test``,
        ``-p/--project``, ``-c/--ckpt``, ``-d/--debug``, ``-s/--seed``,
        ``-f/--postfix``, ``-l/--logdir`` and ``--scale_lr`` registered.
    """

    def str2bool(v):
        """Convert a command-line token to ``bool``.

        Accepts an actual ``bool`` (returned unchanged) or one of the
        case-insensitive spellings yes/true/t/y/1 and no/false/f/n/0.

        Raises:
            argparse.ArgumentTypeError: If the token is none of the above.
        """
        if isinstance(v, bool):
            return v
        lowered = v.lower()
        if lowered in ("yes", "true", "t", "y", "1"):
            return True
        if lowered in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError("Boolean value expected.")

    # Create the parser with the caller-supplied keyword arguments
    parser = argparse.ArgumentParser(**parser_kwargs)

    # NOTE(review): -n, -r and -c combine type=str with const=True and nargs="?".
    # When such a flag is given without a value, argparse stores the const as-is,
    # so the attribute becomes the bool True rather than a string. Left unchanged
    # because callers may rely on it; confirm whether that is intended.
    parser.add_argument(
        "-n",
        "--name",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="postfix for logdir",
    )
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="resume from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "-b",
        "--base",
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. Loaded from left-to-right. "
        "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        default=list(),
    )
    parser.add_argument(
        "-t",
        "--train",
        type=str2bool,
        const=True,
        default=False,
        nargs="?",
        help="train",
    )
    parser.add_argument(
        "--no-test",
        type=str2bool,
        const=True,
        default=False,
        nargs="?",
        help="disable test",
    )
    parser.add_argument(
        "-p",
        "--project",
        help="name of new or path to existing project",
    )
    parser.add_argument(
        "-c",
        "--ckpt",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="load pretrained checkpoint from stable AI",
    )
    parser.add_argument(
        "-d",
        "--debug",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="enable post-mortem debugging",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=23,
        help="seed for seed_everything",
    )
    parser.add_argument(
        "-f",
        "--postfix",
        type=str,
        default="",
        help="post-postfix for default name",
    )
    parser.add_argument(
        "-l",
        "--logdir",
        type=str,
        default="logs",
        help="directory for logging dat shit",
    )
    parser.add_argument(
        "--scale_lr",
        type=str2bool,
        nargs="?",
        const=True,
        default=True,
        help="scale base-lr by ngpu * batch_size * n_accumulate",
    )

    return parser

159

NatalieC323's avatar
NatalieC323 committed
160
# Return the names of the Trainer arguments whose values in `opt` differ from the Trainer defaults
161
def nondefault_trainer_args(opt):
    """List the Trainer argument names whose value on ``opt`` is not the default.

    A throwaway parser is populated with ``Trainer.add_argparse_args`` and
    parsed with an empty argument list to obtain the default values; every
    attribute of that result is then compared against the same attribute on
    ``opt``.

    Args:
        opt: Parsed options object. It must expose every attribute the Trainer
            parser defines, because they are read with ``getattr`` without a
            fallback.

    Returns:
        list[str]: Sorted names of the arguments that differ from the defaults.
    """
    parser = argparse.ArgumentParser()
    # Register the Trainer's own arguments on the empty parser
    parser = Trainer.add_argparse_args(parser)
    # Parsing an empty argument list yields the default values
    args = parser.parse_args([])
    return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))

171

NatalieC323's avatar
NatalieC323 committed
172
# A dataset wrapper class to create a pytorch dataset from an arbitrary object
173
174
175
176
177
178
179
180
181
182
183
184
class WrappedDataset(Dataset):
    """Adapter exposing any sized, indexable object through the pytorch Dataset interface"""

    def __init__(self, dataset):
        # Keep a reference to the wrapped object; nothing is copied.
        self.data = dataset

    def __getitem__(self, idx):
        # Indexing is delegated to the wrapped object.
        wrapped = self.data
        return wrapped[idx]

    def __len__(self):
        # The length is whatever the wrapped object reports.
        wrapped = self.data
        return len(wrapped)

185

NatalieC323's avatar
NatalieC323 committed
186
# A function to initialize worker processes
187
188
189
190
191
192
193
def worker_init_fn(_):
    """Initialise a DataLoader worker: shard iterable datasets and seed numpy.

    Reads the worker's info via ``torch.utils.data.get_worker_info()``. For a
    ``Txt2ImgIterableBaseDataset`` the worker's ``sample_ids`` are set to its
    slice of ``valid_ids``; in every case numpy's global RNG is reseeded with a
    value offset by the worker id.

    Args:
        _: Ignored; the worker id is taken from the worker info instead.

    Returns:
        None: The result of ``np.random.seed`` is returned as-is.
    """
    worker_info = torch.utils.data.get_worker_info()

    dataset = worker_info.dataset
    worker_id = worker_info.id

    if isinstance(dataset, Txt2ImgIterableBaseDataset):
        # Equal-sized share per worker (floor division).
        # NOTE(review): if num_records is not a multiple of num_workers, the
        # trailing records appear to be assigned to no worker -- confirm.
        split_size = dataset.num_records // worker_info.num_workers
        # Restrict this worker to its own contiguous slice of the valid ids
        dataset.sample_ids = dataset.valid_ids[worker_id * split_size : (worker_id + 1) * split_size]
        # Seed from a randomly chosen entry of the current RNG state, offset by the worker id
        current_id = np.random.choice(len(np.random.get_state()[1]), 1)
        return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
    else:
        # Seed from the first entry of the current RNG state, offset by the worker id
        return np.random.seed(np.random.get_state()[1][0] + worker_id)

Fazzie's avatar
Fazzie committed
205

206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
# Provide functionality for creating data loaders based on provided dataset configurations
class DataModuleFromConfig(pl.LightningDataModule):
    """LightningDataModule that builds its datasets and loaders from config objects.

    Each of ``train``, ``validation``, ``test`` and ``predict`` is an optional
    dataset config handed to ``instantiate_from_config``. A ``*_dataloader``
    attribute is attached only for the splits that were actually supplied.
    """

    def __init__(
        self,
        batch_size,
        train=None,
        validation=None,
        test=None,
        predict=None,
        wrap=False,
        num_workers=None,
        shuffle_test_loader=False,
        use_worker_init_fn=False,
        shuffle_val_dataloader=False,
    ):
        """Store the loader settings and register a dataloader per supplied split.

        Args:
            batch_size: Batch size passed to every loader.
            train: Config for the train dataset, or None to omit the split.
            validation: Config for the validation dataset, or None.
            test: Config for the test dataset, or None.
            predict: Config for the predict dataset, or None.
            wrap: If true, ``setup`` wraps every dataset in ``WrappedDataset``.
            num_workers: Loader worker count; defaults to ``batch_size * 2``.
            shuffle_test_loader: Shuffle flag bound into the test loader.
            use_worker_init_fn: Force ``worker_init_fn`` onto every loader.
            shuffle_val_dataloader: Shuffle flag bound into the validation loader.
        """
        super().__init__()
        self.batch_size = batch_size
        self.dataset_configs = dict()
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        self.use_worker_init_fn = use_worker_init_fn
        # Only splits with a config get a dataloader method attached
        if train is not None:
            self.dataset_configs["train"] = train
            self.train_dataloader = self._train_dataloader
        if validation is not None:
            self.dataset_configs["validation"] = validation
            self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
        if test is not None:
            self.dataset_configs["test"] = test
            self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
        if predict is not None:
            self.dataset_configs["predict"] = predict
            self.predict_dataloader = self._predict_dataloader
        self.wrap = wrap

    def prepare_data(self):
        """Instantiate every configured dataset once and discard the result."""
        for data_cfg in self.dataset_configs.values():
            instantiate_from_config(data_cfg)

    def setup(self, stage=None):
        """Instantiate the datasets into ``self.datasets``, wrapping them if requested.

        Args:
            stage: Unused.
        """
        self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)

        if self.wrap:
            for k in self.datasets:
                self.datasets[k] = WrappedDataset(self.datasets[k])

    def _is_iterable(self, split):
        """Return whether the dataset of ``split`` is a Txt2ImgIterableBaseDataset."""
        return isinstance(self.datasets[split], Txt2ImgIterableBaseDataset)

    def _select_init_fn(self, split):
        """Return ``worker_init_fn`` for iterable datasets or when forced, else None."""
        if self._is_iterable(split) or self.use_worker_init_fn:
            return worker_init_fn
        return None

    def _train_dataloader(self):
        """Build the train loader; shuffling is disabled for iterable datasets."""
        return DataLoaderX(
            self.datasets["train"],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=not self._is_iterable("train"),
            worker_init_fn=self._select_init_fn("train"),
        )

    def _val_dataloader(self, shuffle=False):
        """Build the validation loader with the given shuffle flag."""
        return DataLoaderX(
            self.datasets["validation"],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            worker_init_fn=self._select_init_fn("validation"),
            shuffle=shuffle,
        )

    def _test_dataloader(self, shuffle=False):
        """Build the test loader; shuffling is suppressed for iterable datasets."""
        # Inspect the *test* dataset. The previous code looked at the train
        # dataset, which misjudged the test split and raised KeyError when no
        # train split was configured.
        is_iterable_dataset = self._is_iterable("test")

        # do not shuffle dataloader for iterable dataset
        shuffle = shuffle and (not is_iterable_dataset)

        return DataLoaderX(
            self.datasets["test"],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            worker_init_fn=self._select_init_fn("test"),
            shuffle=shuffle,
        )

    def _predict_dataloader(self, shuffle=False):
        """Build the predict loader. ``shuffle`` is accepted but not used."""
        return DataLoaderX(
            self.datasets["predict"],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            worker_init_fn=self._select_init_fn("predict"),
        )
316
317
318


class SetupCallback(Callback):
    """Callback that prepares the run directories and dumps the configs at fit start.

    It also writes a ``last.ckpt`` checkpoint when training is interrupted from
    the keyboard.
    """

    def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
        """Store the run settings.

        Args:
            resume: Truthy when resuming; suppresses the logdir move on non-zero ranks.
            now: Timestamp string used as prefix of the saved config file names.
            logdir: Log directory of the run.
            ckptdir: Checkpoint directory.
            cfgdir: Directory the config YAML files are written to.
            config: Project config (OmegaConf object).
            lightning_config: Lightning config (OmegaConf object).
        """
        super().__init__()
        self.resume = resume
        self.now = now
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        self.config = config
        self.lightning_config = lightning_config

    def on_keyboard_interrupt(self, trainer, pl_module):
        """Save ``last.ckpt`` into the checkpoint directory (global rank 0 only)."""
        if trainer.global_rank == 0:
            print("Summoning checkpoint.")
            ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
            trainer.save_checkpoint(ckpt_path)

    def on_exception(self, trainer, pl_module, exception):
        """Route keyboard interrupts to ``on_keyboard_interrupt``.

        ``on_keyboard_interrupt`` is no longer a callback hook in the
        ``lightning.pytorch`` namespace this file imports from, so without this
        bridge the interrupt checkpoint would never be written. Other
        exceptions are ignored here.
        """
        if isinstance(exception, KeyboardInterrupt):
            self.on_keyboard_interrupt(trainer, pl_module)

    # def on_pretrain_routine_start(self, trainer, pl_module):
    def on_fit_start(self, trainer, pl_module):
        """Create the run directories and save both configs (rank 0), or move a stray logdir."""
        if trainer.global_rank == 0:
            # Create logdirs and save configs
            os.makedirs(self.logdir, exist_ok=True)
            os.makedirs(self.ckptdir, exist_ok=True)
            os.makedirs(self.cfgdir, exist_ok=True)

            # Create trainstep checkpoint directory if necessary
            if "callbacks" in self.lightning_config:
                if "metrics_over_trainsteps_checkpoint" in self.lightning_config["callbacks"]:
                    os.makedirs(os.path.join(self.ckptdir, "trainstep_checkpoints"), exist_ok=True)

            # Print the project config and save it as YAML
            print("Project config")
            print(OmegaConf.to_yaml(self.config))
            OmegaConf.save(self.config, os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))

            # Print the lightning config and save it as YAML under a "lightning" key
            print("Lightning config")
            print(OmegaConf.to_yaml(self.lightning_config))
            OmegaConf.save(
                OmegaConf.create({"lightning": self.lightning_config}),
                os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)),
            )

        else:
            # ModelCheckpoint callback created log directory --- remove it
            # On non-zero ranks of a fresh (non-resumed) run, an existing logdir
            # is moved to <parent>/child_runs/<name> rather than deleted.
            if not self.resume and os.path.exists(self.logdir):
                dst, name = os.path.split(self.logdir)
                dst = os.path.join(dst, "child_runs", name)
                os.makedirs(os.path.split(dst)[0], exist_ok=True)
                try:
                    os.rename(self.logdir, dst)
                except FileNotFoundError:
                    # The directory vanished between the exists() check and the rename
                    pass

    # def on_fit_end(self, trainer, pl_module):
    #     if trainer.global_rank == 0:
    #         ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
    #         rank_zero_info(f"Saving final checkpoint in {ckpt_path}.")
    #         trainer.save_checkpoint(ckpt_path)

381

382
# PyTorch Lightning callback for logging images during training and validation of a deep learning model
383
class ImageLogger(Callback):
    """Callback that writes image grids produced by the module's ``log_images`` method.

    Images are saved as PNG files under ``<logger.save_dir>/images/<split>`` and,
    for logger types registered in ``self.logger_log_images``, also sent to the
    logger's experiment object.
    """

    def __init__(
        self,
        batch_frequency,  # Frequency of batches on which to log images
        max_images,  # Maximum number of images to log
        clamp=True,  # Whether to clamp pixel values to [-1,1]
        increase_log_steps=True,  # Whether to also log at power-of-two steps up to batch_frequency
        rescale=True,  # Whether to rescale pixel values from [-1,1] to [0,1] for local files
        disabled=False,  # Whether to disable logging
        log_on_batch_idx=False,  # Whether to log on batch index instead of global step
        log_first_step=False,  # Whether to log on the first step
        log_images_kwargs=None,  # Additional keyword arguments to pass to log_images method
    ):
        super().__init__()
        self.rescale = rescale
        self.batch_freq = batch_frequency
        self.max_images = max_images
        # Dictionary of logger classes and their corresponding logging methods
        # NOTE(review): _testtube calls experiment.add_image; confirm that the
        # CSVLogger experiment object actually provides it.
        self.logger_log_images = {
            pl.loggers.CSVLogger: self._testtube,
        }
        # Powers of two from 1 up to the largest one not exceeding batch_frequency
        self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp
        self.disabled = disabled
        self.log_on_batch_idx = log_on_batch_idx
        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
        self.log_first_step = log_first_step

    @rank_zero_only  # Decorator from lightning.pytorch.utilities
    def _testtube(
        self,
        pl_module,  # The PyTorch Lightning module
        images,  # A dictionary of images to log
        batch_idx,  # The batch index (unused here)
        split,  # The split (train/val) on which to log the images
    ):
        """Send one image grid per entry of ``images`` to the module logger's experiment."""
        for k in images:
            grid = torchvision.utils.make_grid(images[k])
            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w

            tag = f"{split}/{k}"
            # Add image grid to logger's experiment
            pl_module.logger.experiment.add_image(tag, grid, global_step=pl_module.global_step)

    @rank_zero_only
    def log_local(
        self,
        save_dir,  # Base directory; files go to <save_dir>/images/<split>
        split,  # The split (train/val) on which to log the images
        images,  # A dictionary of images to log
        global_step,  # The global step
        current_epoch,  # The current epoch
        batch_idx,  # The batch index
    ):
        """Save one PNG image grid per entry of ``images`` to the local file system."""
        root = os.path.join(save_dir, "images", split)
        for k in images:
            grid = torchvision.utils.make_grid(images[k], nrow=4)
            if self.rescale:
                grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
            # c,h,w -> h,w,c; a trailing channel axis of size 1 is dropped
            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
            grid = grid.numpy()
            grid = (grid * 255).astype(np.uint8)
            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
            path = os.path.join(root, filename)
            os.makedirs(os.path.split(path)[0], exist_ok=True)
            # Save image grid as PNG file
            Image.fromarray(grid).save(path)

    def log_img(self, pl_module, batch, batch_idx, split="train"):
        """Log images for ``batch`` to the local file system and to the logger, if due.

        Nothing happens unless ``check_frequency`` accepts the current index,
        the module has a callable ``log_images`` and ``max_images`` is positive.
        """
        check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
        # check if it's time to log an image batch
        if (
            self.check_frequency(check_idx)
            and hasattr(pl_module, "log_images")  # batch_idx % self.batch_freq == 0
            and callable(pl_module.log_images)
            and self.max_images > 0
        ):
            logger = type(pl_module.logger)

            # Generate the images in eval mode; the previous mode is restored in
            # the finally clause even if logging raises.
            is_train = pl_module.training
            if is_train:
                pl_module.eval()

            try:
                with torch.no_grad():
                    # Get images from log_images method of the pl_module
                    images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)

                # Keep at most max_images per entry; tensors are detached, moved
                # to CPU and optionally clamped to [-1, 1]
                for k in images:
                    N = min(images[k].shape[0], self.max_images)
                    images[k] = images[k][:N]
                    if isinstance(images[k], torch.Tensor):
                        images[k] = images[k].detach().cpu()
                        if self.clamp:
                            images[k] = torch.clamp(images[k], -1.0, 1.0)

                # Log images locally to file system
                self.log_local(
                    pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch, batch_idx
                )

                # log the images using the logger; unregistered logger types are a no-op
                logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
                logger_log_images(pl_module, images, pl_module.global_step, split)
            finally:
                # switch back to training mode if necessary
                if is_train:
                    pl_module.train()

    def check_frequency(self, check_idx):
        """Return True when ``check_idx`` is a logging step.

        A step qualifies if it is a multiple of ``batch_freq`` or listed in
        ``log_steps``, and is either positive or ``log_first_step`` is set.
        Each accepted step consumes the first remaining entry of ``log_steps``.
        """
        if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
            check_idx > 0 or self.log_first_step
        ):
            # Once the warm-up list is used up there is nothing left to consume;
            # skip silently instead of printing an IndexError on every log step.
            if self.log_steps:
                self.log_steps.pop(0)
            return True
        return False

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Do nothing: image logging on train batches is currently switched off."""
        # if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
        #     self.log_img(pl_module, batch, batch_idx, split="train")
        pass

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        """Log validation images unless disabled or still at global step 0.

        ``dataloader_idx`` is accepted, and ignored, so the hook also works when
        the trainer passes it.
        """
        if not self.disabled and pl_module.global_step > 0:
            self.log_img(pl_module, batch, batch_idx, split="val")
        # log gradients during calibration if necessary
        if hasattr(pl_module, "calibrate_grad_norm"):
            if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
                # NOTE(review): log_gradients is not defined in this class as
                # visible here; this call raises AttributeError unless a
                # subclass or base class provides it -- confirm.
                self.log_gradients(trainer, pl_module, batch_idx=batch_idx)


class CUDACallback(Callback):
    """Callback that reports per-epoch wall time and peak CUDA memory."""

    # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py

    def on_train_start(self, trainer, pl_module):
        """Announce the start of training."""
        rank_zero_info("Training is starting")

    def on_train_end(self, trainer, pl_module):
        """Announce the end of training."""
        rank_zero_info("Training is ending")

    def on_train_epoch_start(self, trainer, pl_module):
        """Reset the peak-memory counter and start the epoch timer."""
        device_index = trainer.strategy.root_device.index
        # Reset the memory use counter
        torch.cuda.reset_peak_memory_stats(device_index)
        torch.cuda.synchronize(device_index)
        self.start_time = time.time()

    def on_train_epoch_end(self, trainer, pl_module):
        """Log the epoch duration and the peak allocated memory in MiB."""
        device_index = trainer.strategy.root_device.index
        torch.cuda.synchronize(device_index)
        # bytes -> MiB
        max_memory = torch.cuda.max_memory_allocated(device_index) / 2**20
        epoch_time = time.time() - self.start_time

        try:
            # Combine the values across processes via the strategy.
            # NOTE(review): presumably a mean, since the messages below say
            # "Average" -- confirm against the strategy's default reduce op.
            max_memory = trainer.strategy.reduce(max_memory)
            epoch_time = trainer.strategy.reduce(epoch_time)

            rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
            rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
        except AttributeError:
            # Best effort: reporting is skipped if an attribute used above is missing
            pass


if __name__ == "__main__":
    # custom parser to specify config files, train, test and debug mode,
    # postfix, resume.
    # `--key value` arguments are interpreted as arguments to the trainer.
    # `nested.key=value` arguments are interpreted as config parameters.
    # configs are merged from left-to-right followed by command line parameters.

    # model:
    #   base_learning_rate: float
    #   target: path to lightning module
    #   params:
    #       key: value
    # data:
    #   target: main.DataModuleFromConfig
    #   params:
    #      batch_size: int
    #      wrap: bool
    #      train:
    #          target: path to train dataset
    #          params:
    #              key: value
    #      validation:
    #          target: path to validation dataset
    #          params:
    #              key: value
    #      test:
    #          target: path to test dataset
    #          params:
    #              key: value
587
    # lightning: (optional, has sane defaults and can be specified on cmdline)
588
589
590
591
592
593
594
595
596
597
598
599
    #   trainer:
    #       additional arguments to trainer
    #   logger:
    #       logger to instantiate
    #   modelcheckpoint:
    #       modelcheckpoint to instantiate
    #   callbacks:
    #       callback1:
    #           target: importpath
    #           params:
    #               key: value

NatalieC323's avatar
NatalieC323 committed
600
    # get the current time to create a new logging directory
601
602
603
604
605
606
607
608
609
610
611
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")

    # add cwd for convenience and to make classes in this file available when
    # running as `python main.py`
    # (in particular `main.DataModuleFromConfig`)
    sys.path.append(os.getcwd())

    parser = get_parser()
    parser = Trainer.add_argparse_args(parser)

    opt, unknown = parser.parse_known_args()
612
    # Reject the combination of -n/--name and -r/--resume; only one of them may be given
613
    if opt.name and opt.resume:
614
615
616
617
618
        raise ValueError(
            "-n/--name and -r/--resume cannot be specified both."
            "If you want to resume training in a new log folder, "
            "use -n/--name in combination with --resume_from_checkpoint"
        )
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
619

NatalieC323's avatar
NatalieC323 committed
620
    # Check if the "resume" option is specified, resume training from the checkpoint if it is true
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
621
    ckpt = None
622
    if opt.resume:
Fazzie's avatar
Fazzie committed
623
        rank_zero_info("Resuming from {}".format(opt.resume))
624
625
626
627
628
629
630
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            paths = opt.resume.split("/")
            # idx = len(paths)-paths[::-1].index("logs")+1
            # logdir = "/".join(paths[:idx])
            logdir = "/".join(paths[:-2])
Fazzie's avatar
Fazzie committed
631
            rank_zero_info("logdir: {}".format(logdir))
632
633
634
635
636
637
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")

NatalieC323's avatar
NatalieC323 committed
638
        # Finds all ".yaml" configuration files in the log directory and adds them to the list of base configurations
639
640
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
        opt.base = base_configs + opt.base
NatalieC323's avatar
NatalieC323 committed
641
        # Gets the name of the current log directory by splitting the path and taking the last element.
642
643
644
645
646
647
        _tmp = logdir.split("/")
        nowname = _tmp[-1]
    else:
        if opt.name:
            name = "_" + opt.name
        elif opt.base:
Fazzie's avatar
Fazzie committed
648
            rank_zero_info("Using base config {}".format(opt.base))
649
650
651
652
653
654
655
656
            cfg_fname = os.path.split(opt.base[0])[-1]
            cfg_name = os.path.splitext(cfg_fname)[0]
            name = "_" + cfg_name
        else:
            name = ""
        nowname = now + name + opt.postfix
        logdir = os.path.join(opt.logdir, nowname)

NatalieC323's avatar
NatalieC323 committed
657
        # Set the checkpoint path if the 'ckpt' option is specified
Fazzie's avatar
Fazzie committed
658
659
660
        if opt.ckpt:
            ckpt = opt.ckpt

NatalieC323's avatar
NatalieC323 committed
661
    # Create the checkpoint and configuration directories within the log directory.
662
663
    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")
NatalieC323's avatar
NatalieC323 committed
664
    # Sets the seed for the random number generator to ensure reproducibility
665
666
    seed_everything(opt.seed)

667
    # Initialize and save configuration using the OmegaConf library.
668
669
670
671
672
673
674
675
    try:
        # init and save configs
        configs = [OmegaConf.load(cfg) for cfg in opt.base]
        cli = OmegaConf.from_dotlist(unknown)
        config = OmegaConf.merge(*configs, cli)
        lightning_config = config.pop("lightning", OmegaConf.create())
        # merge trainer cli with config
        trainer_config = lightning_config.get("trainer", OmegaConf.create())
Fazzie's avatar
Fazzie committed
676

677
678
679
        for k in nondefault_trainer_args(opt):
            trainer_config[k] = getattr(opt, k)

NatalieC323's avatar
NatalieC323 committed
680
        # Check whether the accelerator is gpu
681
682
683
684
685
686
687
688
689
690
691
692
693
694
        if not trainer_config["accelerator"] == "gpu":
            del trainer_config["accelerator"]
            cpu = True
        else:
            cpu = False
        trainer_opt = argparse.Namespace(**trainer_config)
        lightning_config.trainer = trainer_config

        # model
        use_fp16 = trainer_config.get("precision", 32) == 16
        if use_fp16:
            config.model["params"].update({"use_fp16": True})
        else:
            config.model["params"].update({"use_fp16": False})
Fazzie's avatar
Fazzie committed
695
696

        if ckpt is not None:
697
            # If a checkpoint path is specified in the ckpt variable, the code updates the "ckpt" key in the "params" dictionary of the config.model configuration with the value of ckpt
Fazzie's avatar
Fazzie committed
698
699
            config.model["params"].update({"ckpt": ckpt})
            rank_zero_info("Using ckpt_path = {}".format(config.model["params"]["ckpt"]))
Fazzie's avatar
Fazzie committed
700

natalie_cao's avatar
natalie_cao committed
701
        model = LatentDiffusion(**config.model.get("params", dict()))
702
703
704
705
        # trainer and callbacks
        trainer_kwargs = dict()

        # config the logger
NatalieC323's avatar
NatalieC323 committed
706
        # Default logger configs used to log training metrics during the training process.
707
708
        default_logger_cfgs = {
            "wandb": {
709
710
711
712
713
714
                "name": nowname,
                "save_dir": logdir,
                "offline": opt.debug,
                "id": nowname,
            },
            "tensorboard": {"save_dir": logdir, "name": "diff_tb", "log_graph": True},
715
716
        }

NatalieC323's avatar
NatalieC323 committed
717
        # Set up the logger: WandbLogger when a logger config is given, otherwise TensorBoardLogger with the default config
718
719
720
        default_logger_cfg = default_logger_cfgs["tensorboard"]
        if "logger" in lightning_config:
            logger_cfg = lightning_config.logger
natalie_cao's avatar
natalie_cao committed
721
            trainer_kwargs["logger"] = WandbLogger(**logger_cfg)
722
723
        else:
            logger_cfg = default_logger_cfg
natalie_cao's avatar
natalie_cao committed
724
            trainer_kwargs["logger"] = TensorBoardLogger(**logger_cfg)
725
726
727
728

        # Configure the strategy: ColossalAIStrategy when one is configured, otherwise the default DDPStrategy
        if "strategy" in trainer_config:
            strategy_cfg = trainer_config["strategy"]
natalie_cao's avatar
natalie_cao committed
729
            trainer_kwargs["strategy"] = ColossalAIStrategy(**strategy_cfg)
730
        else:
natalie_cao's avatar
natalie_cao committed
731
732
            strategy_cfg = {"find_unused_parameters": False}
            trainer_kwargs["strategy"] = DDPStrategy(**strategy_cfg)
733

NatalieC323's avatar
NatalieC323 committed
734
        # Set up ModelCheckpoint callback to save best models
735
736
737
        # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
        # specify which metric is used to determine best models
        default_modelckpt_cfg = {
738
739
740
741
742
            "dirpath": ckptdir,
            "filename": "{epoch:06}",
            "verbose": True,
            "save_last": True,
        }
743
        if hasattr(model, "monitor"):
natalie_cao's avatar
natalie_cao committed
744
745
            default_modelckpt_cfg["monitor"] = model.monitor
            default_modelckpt_cfg["save_top_k"] = 3
746
747

        if "modelcheckpoint" in lightning_config:
natalie_cao's avatar
natalie_cao committed
748
            modelckpt_cfg = lightning_config.modelcheckpoint["params"]
749
        else:
Fazzie's avatar
Fazzie committed
750
            modelckpt_cfg = OmegaConf.create()
751
        modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
752
        if version.parse(pl.__version__) < version.parse("1.4.0"):
natalie_cao's avatar
natalie_cao committed
753
754
            trainer_kwargs["checkpoint_callback"] = ModelCheckpoint(**modelckpt_cfg)

755
        # Create an empty OmegaConf configuration object
natalie_cao's avatar
natalie_cao committed
756
757

        callbacks_cfg = OmegaConf.create()
758
759

        # Instantiate items according to the configs
natalie_cao's avatar
natalie_cao committed
760
761
        trainer_kwargs.setdefault("callbacks", [])
        setup_callback_config = {
762
763
764
765
766
767
768
769
            "resume": opt.resume,  # resume training if applicable
            "now": now,
            "logdir": logdir,  # directory to save the log file
            "ckptdir": ckptdir,  # directory to save the checkpoint file
            "cfgdir": cfgdir,  # directory to save the configuration file
            "config": config,  # configuration dictionary
            "lightning_config": lightning_config,  # LightningModule configuration
        }
natalie_cao's avatar
natalie_cao committed
770
        trainer_kwargs["callbacks"].append(SetupCallback(**setup_callback_config))
771

natalie_cao's avatar
natalie_cao committed
772
        image_logger_config = {
773
774
775
776
            "batch_frequency": 750,  # how frequently to log images
            "max_images": 4,  # maximum number of images to log
            "clamp": True,  # whether to clamp pixel values to [0,1]
        }
natalie_cao's avatar
natalie_cao committed
777
        trainer_kwargs["callbacks"].append(ImageLogger(**image_logger_config))
778

natalie_cao's avatar
natalie_cao committed
779
        learning_rate_logger_config = {
780
781
782
            "logging_interval": "step",  # logging frequency (either 'step' or 'epoch')
            # "log_momentum": True                            # whether to log momentum (currently commented out)
        }
natalie_cao's avatar
natalie_cao committed
783
        trainer_kwargs["callbacks"].append(LearningRateMonitor(**learning_rate_logger_config))
784
785
786

        metrics_over_trainsteps_checkpoint_config = {
            "dirpath": os.path.join(ckptdir, "trainstep_checkpoints"),
natalie_cao's avatar
natalie_cao committed
787
788
            "filename": "{epoch:06}-{step:09}",
            "verbose": True,
789
790
791
792
            "save_top_k": -1,
            "every_n_train_steps": 10000,
            "save_weights_only": True,
        }
natalie_cao's avatar
natalie_cao committed
793
794
        trainer_kwargs["callbacks"].append(ModelCheckpoint(**metrics_over_trainsteps_checkpoint_config))
        trainer_kwargs["callbacks"].append(CUDACallback())
795

796
        # Create a Trainer object with the specified command-line arguments and keyword arguments, and set the log directory
797
        trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
Fazzie's avatar
Fazzie committed
798
        trainer.logdir = logdir
799

NatalieC323's avatar
NatalieC323 committed
800
        # Create a data module based on the configuration file
natalie_cao's avatar
natalie_cao committed
801
802
        data = DataModuleFromConfig(**config.data)

803
804
805
806
807
        # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
        # calling these ourselves should not be necessary but it is.
        # lightning still takes care of proper multiprocessing though
        data.prepare_data()
        data.setup()
Fazzie's avatar
Fazzie committed
808

NatalieC323's avatar
NatalieC323 committed
809
        # Print some information about the datasets in the data module
810
        for k in data.datasets:
Fazzie's avatar
Fazzie committed
811
            rank_zero_info(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")
812

NatalieC323's avatar
NatalieC323 committed
813
814
        # Configure learning rate based on the batch size, base learning rate and number of GPUs
        # If scale_lr is true, calculate the learning rate based on additional factors
natalie_cao's avatar
natalie_cao committed
815
        bs, base_lr = config.data.batch_size, config.model.base_learning_rate
816
817
818
819
        if not cpu:
            ngpu = trainer_config["devices"]
        else:
            ngpu = 1
820
        if "accumulate_grad_batches" in lightning_config.trainer:
821
822
823
            accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
        else:
            accumulate_grad_batches = 1
Fazzie's avatar
Fazzie committed
824
        rank_zero_info(f"accumulate_grad_batches = {accumulate_grad_batches}")
825
826
827
        lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
        if opt.scale_lr:
            model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
Fazzie's avatar
Fazzie committed
828
            rank_zero_info(
829
830
831
832
                "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
                    model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr
                )
            )
833
834
        else:
            model.learning_rate = base_lr
Fazzie's avatar
Fazzie committed
835
836
            rank_zero_info("++++ NOT USING LR SCALING ++++")
            rank_zero_info(f"Setting learning rate to {model.learning_rate:.2e}")
837

NatalieC323's avatar
NatalieC323 committed
838
        # Allow checkpointing via USR1
839
840
841
842
843
844
845
846
847
        def melk(*args, **kwargs):
            """Save a checkpoint to ``last.ckpt`` inside ``ckptdir`` on global rank 0.

            Any positional or keyword arguments are accepted and ignored, so the
            function can be called directly or registered as a signal handler.
            """
            # Only the rank-0 process writes the checkpoint.
            if trainer.global_rank != 0:
                return
            print("Summoning checkpoint.")
            trainer.save_checkpoint(os.path.join(ckptdir, "last.ckpt"))

        def divein(*args, **kwargs):
            """Drop into an interactive debugger on global rank 0.

            Any positional or keyword arguments are accepted and ignored, so the
            function can be registered as a signal handler. pudb is preferred;
            when it is not installed the standard-library pdb is used instead, so
            the handler does not raise ImportError.
            """
            # Only the rank-0 process opens a debugger session.
            if trainer.global_rank != 0:
                return
            try:
                import pudb as debugger
            except ImportError:
                # Same fallback as the post-mortem handler at the end of the script.
                import pdb as debugger
            debugger.set_trace()

        import signal
853

NatalieC323's avatar
NatalieC323 committed
854
        # Assign melk to SIGUSR1 signal and divein to SIGUSR2 signal
855
856
857
        signal.signal(signal.SIGUSR1, melk)
        signal.signal(signal.SIGUSR2, divein)

NatalieC323's avatar
NatalieC323 committed
858
        # Run the training and validation
859
860
861
862
863
864
        if opt.train:
            try:
                trainer.fit(model, data)
            except Exception:
                melk()
                raise
NatalieC323's avatar
NatalieC323 committed
865
866
        # Print the maximum GPU memory allocated during training
        print(f"GPU memory usage: {torch.cuda.max_memory_allocated() / 1024**2:.0f} MB")
867
868
869
        # if not opt.no_test and not trainer.interrupted:
        #     trainer.test(model, data)
    except Exception:
NatalieC323's avatar
NatalieC323 committed
870
        # If there's an exception, debug it if opt.debug is true and the trainer's global rank is 0
871
872
873
874
875
876
877
878
        if opt.debug and trainer.global_rank == 0:
            try:
                import pudb as debugger
            except ImportError:
                import pdb as debugger
            debugger.post_mortem()
        raise
    finally:
NatalieC323's avatar
NatalieC323 committed
879
        # Move the log directory to debug_runs if opt.debug is true, opt.resume is not set, and the trainer's global rank is 0
880
881
882
883
884
885
886
        if opt.debug and not opt.resume and trainer.global_rank == 0:
            dst, name = os.path.split(logdir)
            dst = os.path.join(dst, "debug_runs", name)
            os.makedirs(os.path.split(dst)[0], exist_ok=True)
            os.rename(logdir, dst)
        if trainer.global_rank == 0:
            print(trainer.profiler.summary())