main.py 37.5 KB
Newer Older
Fazzie's avatar
Fazzie committed
1
2
3
4
5
6
7
import argparse
import csv
import datetime
import glob
import importlib
import os
import sys
8
import time
Fazzie's avatar
Fazzie committed
9
10

import numpy as np
11
12
13
import torch
import torchvision

Fazzie's avatar
Fazzie committed
14
15
16
17
18
# Prefer the new `lightning` package layout; fall back to standalone
# `pytorch_lightning`. Catch only ImportError so SystemExit/KeyboardInterrupt
# and genuine bugs inside the package are not silently swallowed.
try:
    import lightning.pytorch as pl
except ImportError:
    import pytorch_lightning as pl

19
from functools import partial
Fazzie's avatar
Fazzie committed
20
21
22

from omegaconf import OmegaConf
from packaging import version
23
24
from PIL import Image
from prefetch_generator import BackgroundGenerator
Fazzie's avatar
Fazzie committed
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from torch.utils.data import DataLoader, Dataset, Subset, random_split

# Import the Lightning helpers from whichever package layout is installed,
# recording which one was used in LIGHTNING_PACK_NAME. Catch only
# ImportError — a bare `except:` would also hide KeyboardInterrupt/SystemExit.
try:
    from lightning.pytorch import seed_everything
    from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
    from lightning.pytorch.trainer import Trainer
    from lightning.pytorch.utilities import rank_zero_info, rank_zero_only
    LIGHTNING_PACK_NAME = "lightning.pytorch."
except ImportError:
    from pytorch_lightning import seed_everything
    from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
    from pytorch_lightning.trainer import Trainer
    from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
    LIGHTNING_PACK_NAME = "pytorch_lightning."
39
40
41

from ldm.data.base import Txt2ImgIterableBaseDataset
from ldm.util import instantiate_from_config
Fazzie's avatar
Fazzie committed
42
43
44

# from ldm.modules.attention import enable_flash_attentions

45
46

class DataLoaderX(DataLoader):
    """DataLoader variant that prefetches batches on a background thread.

    Wrapping the parent iterator in a BackgroundGenerator lets data loading
    overlap with the training computation.
    """

    def __iter__(self):
        # Delegate batch production to DataLoader, then hand the iterator to
        # the background prefetcher.
        base_iterator = super().__iter__()
        return BackgroundGenerator(base_iterator)


def get_parser(**parser_kwargs):
    """Construct the command-line parser used by the training entry point.

    Any keyword arguments are forwarded to argparse.ArgumentParser.
    """

    def str2bool(value):
        # Map the usual textual spellings of a flag onto a real bool;
        # booleans pass straight through, anything else is rejected.
        if isinstance(value, bool):
            return value
        lowered = value.lower()
        if lowered in ("yes", "true", "t", "y", "1"):
            return True
        if lowered in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError("Boolean value expected.")

    parser = argparse.ArgumentParser(**parser_kwargs)

    # Run identification and resumption.
    parser.add_argument("-n", "--name", type=str, const=True, default="", nargs="?",
                        help="postfix for logdir")
    parser.add_argument("-r", "--resume", type=str, const=True, default="", nargs="?",
                        help="resume from logdir or checkpoint in logdir")
    # Config files are merged left-to-right, then overridden by CLI options.
    parser.add_argument("-b", "--base", nargs="*", metavar="base_config.yaml",
                        help="paths to base configs. Loaded from left-to-right. "
                        "Parameters can be overwritten or added with command-line options of the form `--key value`.",
                        default=list())
    # Mode switches.
    parser.add_argument("-t", "--train", type=str2bool, const=True, default=False, nargs="?",
                        help="train")
    parser.add_argument("--no-test", type=str2bool, const=True, default=False, nargs="?",
                        help="disable test")
    parser.add_argument("-p", "--project",
                        help="name of new or path to existing project")
    parser.add_argument("-c", "--ckpt", type=str, const=True, default="", nargs="?",
                        help="load pretrained checkpoint from stable AI")
    parser.add_argument("-d", "--debug", type=str2bool, nargs="?", const=True, default=False,
                        help="enable post-mortem debugging")
    # Reproducibility and output locations.
    parser.add_argument("-s", "--seed", type=int, default=23,
                        help="seed for seed_everything")
    parser.add_argument("-f", "--postfix", type=str, default="",
                        help="post-postfix for default name")
    parser.add_argument("-l", "--logdir", type=str, default="logs",
                        help="directory for logging dat shit")
    parser.add_argument("--scale_lr", type=str2bool, nargs="?", const=True, default=True,
                        help="scale base-lr by ngpu * batch_size * n_accumulate")

    return parser

NatalieC323's avatar
NatalieC323 committed
170
# A function that returns the trainer arguments whose values differ from the Trainer defaults
171
def nondefault_trainer_args(opt):
    """Return the names of trainer options in `opt` that differ from the
    pytorch-lightning Trainer defaults.

    Args:
        opt: a parsed argparse.Namespace containing trainer options.

    Returns:
        Sorted list of attribute names whose value in `opt` differs from the
        Trainer's default.
    """
    # Create an argument parser and populate it with the Trainer defaults.
    parser = argparse.ArgumentParser()
    # `Trainer.add_argparse_args` was removed in lightning 2.x; guard so this
    # helper degrades to "no non-default args" instead of raising
    # AttributeError when the new package layout is in use.
    if hasattr(Trainer, "add_argparse_args"):
        parser = Trainer.add_argparse_args(parser)
    # Parse an empty argv to obtain the default values.
    args = parser.parse_args([])
    # Report every attribute whose value deviates from the default.
    return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))

NatalieC323's avatar
NatalieC323 committed
181
# A dataset wrapper class to create a pytorch dataset from an arbitrary object
182
183
184
185
186
187
188
189
190
191
192
193
class WrappedDataset(Dataset):
    """Expose an arbitrary object with __len__ and __getitem__ as a pytorch Dataset."""

    def __init__(self, dataset):
        # Keep a reference to the underlying indexable collection.
        self.data = dataset

    def __len__(self):
        # Length is delegated to the wrapped object.
        return len(self.data)

    def __getitem__(self, idx):
        # Item access is delegated to the wrapped object.
        return self.data[idx]

NatalieC323's avatar
NatalieC323 committed
194
# A function to initialize worker processes
195
196
197
198
199
200
201
def worker_init_fn(_):
    """DataLoader worker_init_fn: shard iterable datasets across workers and
    decorrelate each worker's numpy RNG.

    The argument (the worker id passed by DataLoader) is ignored; the id is
    read from torch's worker info instead.
    """
    worker_info = torch.utils.data.get_worker_info()

    # NOTE(review): worker_info is None outside a DataLoader worker process;
    # this function assumes it always runs inside one — confirm.
    dataset = worker_info.dataset
    worker_id = worker_info.id

    if isinstance(dataset, Txt2ImgIterableBaseDataset):
        # Divide the dataset into equal contiguous parts, one per worker.
        split_size = dataset.num_records // worker_info.num_workers
        # Assign this worker its slice of sample IDs. Records beyond
        # num_workers * split_size are dropped by the integer division.
        dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
        # Re-seed numpy from a randomly chosen word of the current RNG state,
        # offset by the worker id so each worker draws a different stream.
        # NOTE(review): state word + worker_id can exceed 2**32 - 1, which
        # np.random.seed rejects — confirm inputs stay in range.
        current_id = np.random.choice(len(np.random.get_state()[1]), 1)
        return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
    else:
        # Map-style dataset: just decorrelate the RNG per worker.
        return np.random.seed(np.random.get_state()[1][0] + worker_id)

NatalieC323's avatar
NatalieC323 committed
213
# Provides functionality for creating data loaders based on the provided dataset configurations
214
class DataModuleFromConfig(pl.LightningDataModule):
    """Builds train/validation/test/predict dataloaders from dataset configs.

    Each split is registered only when a config for it is supplied, so
    Lightning sees exactly the dataloader hooks that have data behind them.
    """

    def __init__(self,
                 batch_size,
                 train=None,
                 validation=None,
                 test=None,
                 predict=None,
                 wrap=False,
                 num_workers=None,
                 shuffle_test_loader=False,
                 use_worker_init_fn=False,
                 shuffle_val_dataloader=False):
        """
        Args:
            batch_size: batch size used by every dataloader.
            train/validation/test/predict: per-split dataset configs,
                instantiable via `instantiate_from_config`; omitted splits
                get no dataloader hook.
            wrap: if True, wrap each dataset in `WrappedDataset`.
            num_workers: dataloader worker count; defaults to 2 * batch_size.
            shuffle_test_loader: shuffle the test dataloader.
            use_worker_init_fn: force the custom `worker_init_fn` even for
                map-style datasets.
            shuffle_val_dataloader: shuffle the validation dataloader.
        """
        super().__init__()
        self.batch_size = batch_size
        self.dataset_configs = dict()
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        self.use_worker_init_fn = use_worker_init_fn
        # Register each provided split and expose the matching dataloader hook.
        if train is not None:
            self.dataset_configs["train"] = train
            self.train_dataloader = self._train_dataloader
        if validation is not None:
            self.dataset_configs["validation"] = validation
            self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
        if test is not None:
            self.dataset_configs["test"] = test
            self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
        if predict is not None:
            self.dataset_configs["predict"] = predict
            self.predict_dataloader = self._predict_dataloader
        self.wrap = wrap

    def prepare_data(self):
        # Instantiate every configured dataset once (e.g. to trigger downloads).
        # Instances are discarded; the per-process copies are built in setup().
        for data_cfg in self.dataset_configs.values():
            instantiate_from_config(data_cfg)

    def setup(self, stage=None):
        # Build the actual dataset objects from their configs.
        self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
        # Optionally wrap each dataset so arbitrary indexable objects work.
        if self.wrap:
            for k in self.datasets:
                self.datasets[k] = WrappedDataset(self.datasets[k])

    def _init_fn_for(self, key):
        # Iterable datasets (or an explicit opt-in) need the custom worker
        # initialization so shards and RNG seeds are set per worker.
        if isinstance(self.datasets[key], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
            return worker_init_fn
        return None

    def _train_dataloader(self):
        # Iterable datasets define their own ordering, so never shuffle them.
        is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
        return DataLoaderX(self.datasets["train"],
                           batch_size=self.batch_size,
                           num_workers=self.num_workers,
                           shuffle=not is_iterable_dataset,
                           worker_init_fn=self._init_fn_for('train'))

    def _val_dataloader(self, shuffle=False):
        return DataLoaderX(self.datasets["validation"],
                           batch_size=self.batch_size,
                           num_workers=self.num_workers,
                           worker_init_fn=self._init_fn_for('validation'),
                           shuffle=shuffle)

    def _test_dataloader(self, shuffle=False):
        # BUGFIX: this previously probed self.datasets['train']; the test
        # loader must inspect the *test* dataset to decide on worker init
        # and shuffling.
        is_iterable_dataset = isinstance(self.datasets['test'], Txt2ImgIterableBaseDataset)
        # Do not shuffle a dataloader over an iterable dataset.
        shuffle = shuffle and (not is_iterable_dataset)
        return DataLoaderX(self.datasets["test"],
                           batch_size=self.batch_size,
                           num_workers=self.num_workers,
                           worker_init_fn=self._init_fn_for('test'),
                           shuffle=shuffle)

    def _predict_dataloader(self, shuffle=False):
        return DataLoaderX(self.datasets["predict"],
                           batch_size=self.batch_size,
                           num_workers=self.num_workers,
                           worker_init_fn=self._init_fn_for('predict'))
317
318
319


class SetupCallback(Callback):
    """Prepare run directories and persist configs when training starts.

    Also saves a "last.ckpt" checkpoint when training is interrupted.
    """

    def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
        """
        Args:
            resume: resume path/flag; when falsy, a stale logdir is moved aside.
            now: timestamp string embedded in saved config filenames.
            logdir: root logging directory for this run.
            ckptdir: directory where checkpoints are written.
            cfgdir: directory where config YAMLs are written.
            config: project (model/data) OmegaConf config.
            lightning_config: trainer/callback OmegaConf config.
        """
        super().__init__()
        self.resume = resume
        self.now = now
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        self.config = config
        self.lightning_config = lightning_config

    def on_keyboard_interrupt(self, trainer, pl_module):
        # Save a checkpoint if training is interrupted with ctrl-C.
        # NOTE: this hook only fires on pytorch-lightning < 1.7; newer
        # versions route interrupts through `on_exception` below.
        if trainer.global_rank == 0:
            print("Summoning checkpoint.")
            ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
            trainer.save_checkpoint(ckpt_path)

    def on_exception(self, trainer, pl_module, exception):
        # Lightning >= 1.7 replaced `on_keyboard_interrupt` with this hook;
        # preserve the old ctrl-C checkpointing behavior for new versions.
        if isinstance(exception, KeyboardInterrupt):
            self.on_keyboard_interrupt(trainer, pl_module)

    def on_fit_start(self, trainer, pl_module):
        # Create the run directories and save configs before training starts.
        if trainer.global_rank == 0:
            os.makedirs(self.logdir, exist_ok=True)
            os.makedirs(self.ckptdir, exist_ok=True)
            os.makedirs(self.cfgdir, exist_ok=True)

            # Create the trainstep checkpoint directory when that callback
            # is configured.
            if "callbacks" in self.lightning_config:
                if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']:
                    os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True)
            print("Project config")
            print(OmegaConf.to_yaml(self.config))
            OmegaConf.save(self.config, os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))

            # Save the lightning config alongside the project config.
            print("Lightning config")
            print(OmegaConf.to_yaml(self.lightning_config))
            OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
                           os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))

        else:
            # On non-zero ranks the ModelCheckpoint callback may already have
            # created the log directory --- move it aside unless resuming.
            if not self.resume and os.path.exists(self.logdir):
                dst, name = os.path.split(self.logdir)
                dst = os.path.join(dst, "child_runs", name)
                os.makedirs(os.path.split(dst)[0], exist_ok=True)
                try:
                    os.rename(self.logdir, dst)
                except FileNotFoundError:
                    pass

380

NatalieC323's avatar
NatalieC323 committed
381
# PyTorch Lightning callback for logging images during training and validation of a deep learning model
382
class ImageLogger(Callback):
    """Periodically renders images produced by the module's `log_images`
    method, saving them to disk and forwarding them to the logger."""

    def __init__(self,
                 batch_frequency,           # log every `batch_frequency` steps
                 max_images,                # max number of images per key to log
                 clamp=True,                # clamp pixel values to [-1, 1]
                 increase_log_steps=True,   # also log at exponentially spaced early steps
                 rescale=True,              # rescale pixel values from [-1, 1] to [0, 1]
                 disabled=False,            # disable logging entirely
                 log_on_batch_idx=False,    # key the frequency check on batch index instead of global step
                 log_first_step=False,      # log at step 0 as well
                 log_images_kwargs=None):   # extra kwargs forwarded to `log_images`
        super().__init__()
        self.rescale = rescale
        self.batch_freq = batch_frequency
        self.max_images = max_images
        # Maps logger classes to the method used to push images to them.
        # NOTE(review): `_testtube` calls `experiment.add_image`, which
        # CSVLogger's experiment writer may not provide — confirm.
        self.logger_log_images = {
            pl.loggers.CSVLogger: self._testtube,
        }
        # Exponentially spaced early steps (1, 2, 4, ...) up to batch_frequency.
        self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp
        self.disabled = disabled
        self.log_on_batch_idx = log_on_batch_idx
        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
        self.log_first_step = log_first_step

    @rank_zero_only    # only the rank-zero process logs
    def _testtube(self,
                  pl_module,    # the LightningModule being trained
                  images,       # dict mapping names to image tensors
                  batch_idx,    # current batch index (unused here)
                  split         # "train" or "val", used to namespace the tag
                  ):
        # Push one image grid per key to the logger's experiment.
        for k in images:
            grid = torchvision.utils.make_grid(images[k])
            grid = (grid + 1.0) / 2.0    # -1,1 -> 0,1; c,h,w

            tag = f"{split}/{k}"
            # Add image grid to logger's experiment.
            pl_module.logger.experiment.add_image(tag, grid, global_step=pl_module.global_step)

    @rank_zero_only
    def log_local(self,
                  save_dir,      # base directory for saved images
                  split,         # "train" or "val" subdirectory
                  images,        # dict mapping names to image tensors
                  global_step,   # current global step (for the filename)
                  current_epoch, # current epoch (for the filename)
                  batch_idx      # current batch index (for the filename)
                  ):
    # Save each image grid as a PNG under save_dir/images/<split>.
        root = os.path.join(save_dir, "images", split)
        for k in images:
            grid = torchvision.utils.make_grid(images[k], nrow=4)
            if self.rescale:
                grid = (grid + 1.0) / 2.0    # -1,1 -> 0,1; c,h,w
            # CHW tensor -> HWC numpy uint8 image.
            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
            grid = grid.numpy()
            grid = (grid * 255).astype(np.uint8)
            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
            path = os.path.join(root, filename)
            os.makedirs(os.path.split(path)[0], exist_ok=True)
            # Save image grid as a PNG file.
            Image.fromarray(grid).save(path)

    def log_img(self, pl_module, batch, batch_idx, split="train"):
    # Log one batch of images to both the logger and the local file system.
        check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
        # Only proceed when logging is due and the module can produce images.
        if (self.check_frequency(check_idx) and    # batch_idx % self.batch_freq == 0
                hasattr(pl_module, "log_images") and callable(pl_module.log_images) and self.max_images > 0):
            # Record the logger type; switch to eval mode while rendering.
            logger = type(pl_module.logger)

            is_train = pl_module.training
            if is_train:
                pl_module.eval()

            with torch.no_grad():
                # Ask the module to render images for this batch.
                images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)

            # Truncate to max_images and move tensors to CPU (clamped if requested).
            for k in images:
                N = min(images[k].shape[0], self.max_images)
                images[k] = images[k][:N]
                if isinstance(images[k], torch.Tensor):
                    images[k] = images[k].detach().cpu()
                    if self.clamp:
                        images[k] = torch.clamp(images[k], -1., 1.)

            # Save the images to the local file system.
            self.log_local(pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch,
                           batch_idx)

            # Forward the images to the logger, if one is registered for its type.
            logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
            logger_log_images(pl_module, images, pl_module.global_step, split)

            # Switch back to training mode if necessary.
            if is_train:
                pl_module.train()

    # Check whether it's time to log an image batch.
    def check_frequency(self, check_idx):
        if ((check_idx % self.batch_freq) == 0 or
            (check_idx in self.log_steps)) and (check_idx > 0 or self.log_first_step):
            # Consume the next early-logging step so each fires only once.
            try:
                self.log_steps.pop(0)
            except IndexError as e:
                print(e)
                pass
            return True
        return False

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Training-time image logging is intentionally disabled (hook kept).
        # if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
        #     self.log_img(pl_module, batch, batch_idx, split="train")
        pass

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Log validation images once training has actually progressed.
        if not self.disabled and pl_module.global_step > 0:
            self.log_img(pl_module, batch, batch_idx, split="val")
        # Optionally log gradients during calibration.
        # NOTE(review): `log_gradients` is not defined on this class in this
        # file — confirm it exists before enabling `calibrate_grad_norm`.
        if hasattr(pl_module, 'calibrate_grad_norm'):
            if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
                self.log_gradients(trainer, pl_module, batch_idx=batch_idx)


class CUDACallback(Callback):
    """Report per-epoch wall time and peak CUDA memory across ranks."""
    # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py

    def on_train_start(self, trainer, pl_module):
        rank_zero_info("Training is starting")

    def on_train_end(self, trainer, pl_module):
        rank_zero_info("Training is ending")

    def on_train_epoch_start(self, trainer, pl_module):
        # Reset the peak-memory counter and synchronize so the epoch timer
        # starts after any pending CUDA work has finished.
        torch.cuda.reset_peak_memory_stats(trainer.strategy.root_device.index)
        torch.cuda.synchronize(trainer.strategy.root_device.index)
        self.start_time = time.time()

    def on_train_epoch_end(self, trainer, pl_module):
        # Synchronize before reading the timer so all queued kernels count.
        torch.cuda.synchronize(trainer.strategy.root_device.index)
        max_memory = torch.cuda.max_memory_allocated(trainer.strategy.root_device.index) / 2**20
        epoch_time = time.time() - self.start_time

        try:
            # Average the stats across processes; strategies without a
            # reduce() raise AttributeError and are silently skipped.
            max_memory = trainer.strategy.reduce(max_memory)
            epoch_time = trainer.strategy.reduce(epoch_time)

            rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
            rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
        except AttributeError:
            pass


if __name__ == "__main__":
    # custom parser to specify config files, train, test and debug mode,
    # postfix, resume.
    # `--key value` arguments are interpreted as arguments to the trainer.
    # `nested.key=value` arguments are interpreted as config parameters.
    # configs are merged from left-to-right followed by command line parameters.

    # model:
    #   base_learning_rate: float
    #   target: path to lightning module
    #   params:
    #       key: value
    # data:
    #   target: main.DataModuleFromConfig
    #   params:
    #      batch_size: int
    #      wrap: bool
    #      train:
    #          target: path to train dataset
    #          params:
    #              key: value
    #      validation:
    #          target: path to validation dataset
    #          params:
    #              key: value
    #      test:
    #          target: path to test dataset
    #          params:
    #              key: value
    # lightning: (optional, has sane defaults and can be specified on cmdline)
    #   trainer:
    #       additional arguments to trainer
    #   logger:
    #       logger to instantiate
    #   modelcheckpoint:
    #       modelcheckpoint to instantiate
    #   callbacks:
    #       callback1:
    #           target: importpath
    #           params:
    #               key: value

    # get the current time to create a new logging directory
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")

    # add cwd for convenience and to make classes in this file available when
    # running as `python main.py`
    # (in particular `main.DataModuleFromConfig`)
    sys.path.append(os.getcwd())

    parser = get_parser()
    parser = Trainer.add_argparse_args(parser)

    opt, unknown = parser.parse_known_args()
    # Verify that -n/--name and -r/--resume are not both specified
    if opt.name and opt.resume:
        raise ValueError("-n/--name and -r/--resume cannot be specified both."
                         "If you want to resume training in a new log folder, "
                         "use -n/--name in combination with --resume_from_checkpoint")

    # Check if the "resume" option is specified, resume training from the checkpoint if it is true
    ckpt = None
    if opt.resume:
        rank_zero_info("Resuming from {}".format(opt.resume))
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            paths = opt.resume.split("/")
            # idx = len(paths)-paths[::-1].index("logs")+1
            # logdir = "/".join(paths[:idx])
            logdir = "/".join(paths[:-2])
            rank_zero_info("logdir: {}".format(logdir))
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")

        # Find all ".yaml" configuration files in the log directory and add them to the list of base configurations
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
        opt.base = base_configs + opt.base
        # Get the name of the current log directory by splitting the path and taking the last element
        _tmp = logdir.split("/")
        nowname = _tmp[-1]
    else:
        if opt.name:
            name = "_" + opt.name
        elif opt.base:
            rank_zero_info("Using base config {}".format(opt.base))
            cfg_fname = os.path.split(opt.base[0])[-1]
            cfg_name = os.path.splitext(cfg_fname)[0]
            name = "_" + cfg_name
        else:
            name = ""
        nowname = now + name + opt.postfix
        logdir = os.path.join(opt.logdir, nowname)

        # Set the checkpoint path if the 'ckpt' option is specified
        if opt.ckpt:
            ckpt = opt.ckpt

    # Create the checkpoint and configuration directories within the log directory.
    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")
    # Set the seed for the random number generator to ensure reproducibility
    seed_everything(opt.seed)

    # Initialize and save configuration using the OmegaConf library.
    try:
        # init and save configs
        configs = [OmegaConf.load(cfg) for cfg in opt.base]
        cli = OmegaConf.from_dotlist(unknown)
        config = OmegaConf.merge(*configs, cli)
        lightning_config = config.pop("lightning", OmegaConf.create())
        # merge trainer cli with config
        trainer_config = lightning_config.get("trainer", OmegaConf.create())

        for k in nondefault_trainer_args(opt):
            trainer_config[k] = getattr(opt, k)

        # Check whether the accelerator is gpu
        if not trainer_config["accelerator"] == "gpu":
            del trainer_config["accelerator"]
            cpu = True
        else:
            cpu = False
        trainer_opt = argparse.Namespace(**trainer_config)
        lightning_config.trainer = trainer_config

        # model
        use_fp16 = trainer_config.get("precision", 32) == 16
        if use_fp16:
            config.model["params"].update({"use_fp16": True})
        else:
            config.model["params"].update({"use_fp16": False})

        if ckpt is not None:
            # Propagate the resolved checkpoint path into the model config so
            # the model is restored from it on instantiation.
            config.model["params"].update({"ckpt": ckpt})
            rank_zero_info("Using ckpt_path = {}".format(config.model["params"]["ckpt"]))

        model = instantiate_from_config(config.model)
        # trainer and callbacks
        trainer_kwargs = dict()

        # config the logger
        # Default logger configs to log training metrics during the training process.
        # These loggers are specified as targets in the dictionary, along with the
        # configuration settings specific to each logger.
        default_logger_cfgs = {
            "wandb": {
                "target": LIGHTNING_PACK_NAME + "loggers.WandbLogger",
                "params": {
                    "name": nowname,
                    "save_dir": logdir,
                    "offline": opt.debug,
                    "id": nowname,
                }
            },
            "tensorboard": {
                "target": LIGHTNING_PACK_NAME + "loggers.TensorBoardLogger",
                "params": {
                    "save_dir": logdir,
                    "name": "diff_tb",
                    "log_graph": True
                }
            }
        }

        # Set up the logger for TensorBoard
        default_logger_cfg = default_logger_cfgs["tensorboard"]
        if "logger" in lightning_config:
            logger_cfg = lightning_config.logger
        else:
            logger_cfg = default_logger_cfg
        logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
        trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)

        # config the strategy, default is ddp
        if "strategy" in trainer_config:
            strategy_cfg = trainer_config["strategy"]
            strategy_cfg["target"] = LIGHTNING_PACK_NAME + strategy_cfg["target"]
        else:
            strategy_cfg = {
                "target": LIGHTNING_PACK_NAME + "strategies.DDPStrategy",
                "params": {
                    "find_unused_parameters": False
                }
            }

        trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)

        # Set up ModelCheckpoint callback to save best models
        # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
        # specify which metric is used to determine best models
        default_modelckpt_cfg = {
            "target": LIGHTNING_PACK_NAME + "callbacks.ModelCheckpoint",
            "params": {
                "dirpath": ckptdir,
                "filename": "{epoch:06}",
                "verbose": True,
                "save_last": True,
            }
        }
        if hasattr(model, "monitor"):
            default_modelckpt_cfg["params"]["monitor"] = model.monitor
            default_modelckpt_cfg["params"]["save_top_k"] = 3

        if "modelcheckpoint" in lightning_config:
            modelckpt_cfg = lightning_config.modelcheckpoint
        else:
            modelckpt_cfg = OmegaConf.create()
        modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
        if version.parse(pl.__version__) < version.parse('1.4.0'):
            trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)

        # Set up various callbacks, including logging, learning rate monitoring, and CUDA management
        # add callback which sets up log directory
        default_callbacks_cfg = {
            "setup_callback": {                           # callback to set up the training
                "target": "main.SetupCallback",
                "params": {
                    "resume": opt.resume,                 # resume training if applicable
                    "now": now,
                    "logdir": logdir,                     # directory to save the log file
                    "ckptdir": ckptdir,                   # directory to save the checkpoint file
                    "cfgdir": cfgdir,                     # directory to save the configuration file
                    "config": config,                     # configuration dictionary
                    "lightning_config": lightning_config, # LightningModule configuration
                }
            },
            "image_logger": {                             # callback to log image data
                "target": "main.ImageLogger",
                "params": {
                    "batch_frequency": 750,               # how frequently to log images
                    "max_images": 4,                      # maximum number of images to log
                    "clamp": True                         # whether to clamp pixel values to [0,1]
                }
            },
            "learning_rate_logger": {                     # callback to log learning rate
                "target": "main.LearningRateMonitor",
                "params": {
                    "logging_interval": "step",           # logging frequency (either 'step' or 'epoch')
                    # "log_momentum": True                # whether to log momentum (currently disabled)
                }
            },
            "cuda_callback": {                            # callback to handle CUDA-related operations
                "target": "main.CUDACallback"
            },
        }

        # If the LightningModule configuration has specified callbacks, use those
        # Otherwise, create an empty OmegaConf configuration object
        if "callbacks" in lightning_config:
            callbacks_cfg = lightning_config.callbacks
        else:
            callbacks_cfg = OmegaConf.create()

        # If the 'metrics_over_trainsteps_checkpoint' callback is specified in the
        # LightningModule configuration, update the default callbacks configuration
        if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg:
            print(
                'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
            default_metrics_over_trainsteps_ckpt_dict = {
                'metrics_over_trainsteps_checkpoint': {
                    "target": LIGHTNING_PACK_NAME + 'callbacks.ModelCheckpoint',
                    'params': {
                        "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
                        "filename": "{epoch:06}-{step:09}",
                        "verbose": True,
                        'save_top_k': -1,
                        'every_n_train_steps': 10000,
                        'save_weights_only': True
                    }
                }
            }
            default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)

        # Merge the default callbacks configuration with the specified callbacks configuration, and instantiate the callbacks
        callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)

        trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]

        # Create a Trainer object with the specified command-line arguments and keyword arguments, and set the log directory
        trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
        trainer.logdir = logdir

        # Create a data module based on the configuration file
        data = instantiate_from_config(config.data)
        # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
        # calling these ourselves should not be necessary but it is.
        # lightning still takes care of proper multiprocessing though
        data.prepare_data()
        data.setup()

        # Print some information about the datasets in the data module
        for k in data.datasets:
            rank_zero_info(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")

        # Configure learning rate based on the batch size, base learning rate and number of GPUs
        # If scale_lr is true, calculate the learning rate based on additional factors
        bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
        if not cpu:
            ngpu = trainer_config["devices"]
        else:
            ngpu = 1
        if 'accumulate_grad_batches' in lightning_config.trainer:
            accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
        else:
            accumulate_grad_batches = 1
        rank_zero_info(f"accumulate_grad_batches = {accumulate_grad_batches}")
        lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
        if opt.scale_lr:
            model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
            rank_zero_info(
                "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)"
                .format(model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
        else:
            model.learning_rate = base_lr
            rank_zero_info("++++ NOT USING LR SCALING ++++")
            rank_zero_info(f"Setting learning rate to {model.learning_rate:.2e}")

    # Allow checkpointing via USR1
        def melk(*args, **kwargs):
            # run all checkpoint hooks
            if trainer.global_rank == 0:
                print("Summoning checkpoint.")
                ckpt_path = os.path.join(ckptdir, "last.ckpt")
                trainer.save_checkpoint(ckpt_path)

        def divein(*args, **kwargs):
            # drop into an interactive debugger on demand
            if trainer.global_rank == 0:
                import pudb
                pudb.set_trace()

        import signal
        # Assign melk to SIGUSR1 signal and divein to SIGUSR2 signal
        signal.signal(signal.SIGUSR1, melk)
        signal.signal(signal.SIGUSR2, divein)

        # Run the training and validation
        if opt.train:
            try:
                trainer.fit(model, data)
            except Exception:
                # Best-effort checkpoint before re-raising the failure.
                melk()
                raise
        # Print the maximum GPU memory allocated during training
        print(f"GPU memory usage: {torch.cuda.max_memory_allocated() / 1024**2:.0f} MB")
        # if not opt.no_test and not trainer.interrupted:
        #     trainer.test(model, data)
    except Exception:
        # If there's an exception, debug it if opt.debug is true and the trainer's global rank is 0
        if opt.debug and trainer.global_rank == 0:
            try:
                import pudb as debugger
            except ImportError:
                import pdb as debugger
            debugger.post_mortem()
        raise
    finally:
        # Move the log directory to debug_runs if opt.debug is true and the trainer's global rank is 0
        if opt.debug and not opt.resume and trainer.global_rank == 0:
            dst, name = os.path.split(logdir)
            dst = os.path.join(dst, "debug_runs", name)
            os.makedirs(os.path.split(dst)[0], exist_ok=True)
            os.rename(logdir, dst)
        if trainer.global_rank == 0:
            print(trainer.profiler.summary())