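"""Training entry point for OpenFold.

Wraps AlphaFold in a PyTorch Lightning module and wires together data
loading, EMA weight tracking, checkpointing, early stopping, and optional
DeepSpeed/DDP distributed training.
"""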
import argparse
import logging
import os
import random
import sys
import time

import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
from pytorch_lightning.plugins.environments import SLURMEnvironment
import torch

from openfold.config import model_config
from openfold.data.data_modules import (
    OpenFoldDataModule,
    DummyDataLoader,
)
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants
from openfold.utils.argparse import remove_arguments
from openfold.utils.callbacks import (
    EarlyStoppingVerbose,
)
from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.loss import AlphaFoldLoss, lddt_ca
from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
from openfold.utils.seed import seed_everything
from openfold.utils.superimposition import superimpose
from openfold.utils.tensor_utils import tensor_tree_map
from openfold.utils.validation_metrics import (
    drmsd,
    gdt_ts,
    gdt_ha,
)
from scripts.zero_to_fp32 import (
    get_fp32_state_dict_from_zero_checkpoint,
    get_global_step_from_zero_checkpoint
)

from openfold.utils.logger import PerformanceLoggingCallback


class OpenFoldWrapper(pl.LightningModule):
    def __init__(self, config):
        super(OpenFoldWrapper, self).__init__()
        self.config = config
        self.model = AlphaFold(config)
        self.loss = AlphaFoldLoss(config.loss)
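        # An exponential moving average (EMA) of the weights is kept alongside
        # the raw training weights; validation and checkpoints use the EMA
        # copy (see validation_step and on_save_checkpoint below).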
        self.ema = ExponentialMovingAverage(
            model=self.model, decay=config.ema.decay
        )
        
        self.cached_weights = None
        self.last_lr_step = -1

    def forward(self, batch):
        return self.model(batch)

    def _log(self, loss_breakdown, batch, outputs, train=True):
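        # Training losses are logged per step, plus an epoch-level mean;
        # validation losses are aggregated over the epoch only.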
        phase = "train" if train else "val"
        for loss_name, indiv_loss in loss_breakdown.items():
            self.log(
                f"{phase}/{loss_name}", 
                indiv_loss, 
                on_step=train, on_epoch=(not train), logger=True,
            )

            if(train):
                self.log(
                    f"{phase}/{loss_name}_epoch",
                    indiv_loss,
                    on_step=False, on_epoch=True, logger=True,
                )

        with torch.no_grad():
            other_metrics = self._compute_validation_metrics(
                batch, 
                outputs,
                superimposition_metrics=(not train)
            )

        for k,v in other_metrics.items():
            self.log(
                f"{phase}/{k}", 
                v, 
                on_step=False, on_epoch=True, logger=True
            )

    def training_step(self, batch, batch_idx):
        if(self.ema.device != batch["aatype"].device):
            self.ema.to(batch["aatype"].device)

        # Run the model
        outputs = self(batch)

        # Remove the recycling dimension
        batch = tensor_tree_map(lambda t: t[..., -1], batch)
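        # (Features are stacked along a trailing recycling dimension; the
        # loss is computed against the final recycling iteration only.)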

        # Compute loss
        loss, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
        )

        # Log it
        self._log(loss_breakdown, batch, outputs)

        return loss

    def on_before_zero_grad(self, *args, **kwargs):
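        # Lightning calls this hook after every optimizer step, before the
        # gradients are zeroed; fold the freshly updated weights into the EMA.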
        self.ema.update(self.model)

    def validation_step(self, batch, batch_idx):
        # At the start of validation, load the EMA weights
        if(self.cached_weights is None):
            # model.state_dict() contains references to model weights rather
            # than copies. Therefore, we need to clone them before calling 
            # load_state_dict().
            clone_param = lambda t: t.detach().clone()
            self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
            self.model.load_state_dict(self.ema.state_dict()["params"])
       
        # Run the model
        outputs = self(batch)
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

        # Compute loss and other metrics
        batch["use_clamped_fape"] = 0.
        _, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
        )

        self._log(loss_breakdown, batch, outputs, train=False)
        
    def validation_epoch_end(self, _):
        # Restore the model weights to normal
        self.model.load_state_dict(self.cached_weights)
        self.cached_weights = None

    def _compute_validation_metrics(self, 
        batch, 
        outputs, 
        superimposition_metrics=False
    ):
        metrics = {}
        
        gt_coords = batch["all_atom_positions"]
        pred_coords = outputs["final_atom_positions"]
        all_atom_mask = batch["all_atom_mask"]
    
        # This is super janky for superimposition. Fix later
        gt_coords_masked = gt_coords * all_atom_mask[..., None]
        pred_coords_masked = pred_coords * all_atom_mask[..., None]
        ca_pos = residue_constants.atom_order["CA"]
        gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :]
        pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :]
        all_atom_mask_ca = all_atom_mask[..., ca_pos]
    
        lddt_ca_score = lddt_ca(
            pred_coords,
            gt_coords,
            all_atom_mask,
            eps=self.config.globals.eps,
            per_residue=False,
        )
   
        metrics["lddt_ca"] = lddt_ca_score
   
        drmsd_ca_score = drmsd(
            pred_coords_masked_ca,
            gt_coords_masked_ca,
            mask=all_atom_mask_ca, # still required here to compute n
        )
   
        metrics["drmsd_ca"] = drmsd_ca_score
    
        if(superimposition_metrics):
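            # GDT metrics are only meaningful after an optimal alignment, so
            # superimpose the predicted CA trace onto the ground truth first.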
            superimposed_pred, alignment_rmsd = superimpose(
                gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
            )
            gdt_ts_score = gdt_ts(
                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
            )
            gdt_ha_score = gdt_ha(
                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
            )

            metrics["alignment_rmsd"] = alignment_rmsd
            metrics["gdt_ts"] = gdt_ts_score
            metrics["gdt_ha"] = gdt_ha_score
    
        return metrics

    def configure_optimizers(self, 
        learning_rate: float = 1e-3,
        eps: float = 1e-5,
    ) -> dict:
        # Ignored as long as a DeepSpeed optimizer is configured
        optimizer = torch.optim.Adam(
            self.model.parameters(), 
            lr=learning_rate, 
            eps=eps
        )

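        # When resuming mid-run, torch LR schedulers require an 'initial_lr'
        # entry in each param group; inject it so the scheduler below can
        # pick up from self.last_lr_step.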
        if self.last_lr_step != -1:
            for group in optimizer.param_groups:
                if 'initial_lr' not in group:
                    group['initial_lr'] = learning_rate

        lr_scheduler = AlphaFoldLRScheduler(
            optimizer,
            last_epoch=self.last_lr_step,
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
                "name": "AlphaFoldLRScheduler",
            }
        }

    def on_load_checkpoint(self, checkpoint):
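        # Restore the EMA state. If templates are disabled in this config,
        # any template-tower weights in the checkpoint's EMA dict are dropped
        # to match the reduced model.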
        ema = checkpoint["ema"]
        if(not self.model.template_config.enabled):
            ema["params"] = {k:v for k,v in ema["params"].items() if not "template" in k}
        self.ema.load_state_dict(ema)

    def on_save_checkpoint(self, checkpoint):
        checkpoint["ema"] = self.ema.state_dict()

    def resume_last_lr_step(self, lr_step):
        self.last_lr_step = lr_step


def main(args):
    if(args.seed is not None):
        seed_everything(args.seed) 

    config = model_config(
        args.config_preset, 
        train=True, 
        low_prec=(str(args.precision) == "16")
    ) 
    
    model_module = OpenFoldWrapper(config)
    if(args.resume_from_ckpt):
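        # DeepSpeed ZeRO checkpoints are directories of shards; ordinary
        # single-file checkpoints carry "global_step" directly.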
        if(os.path.isdir(args.resume_from_ckpt)):  
            last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt)
        else:
            sd = torch.load(args.resume_from_ckpt)
            last_global_step = int(sd['global_step'])
        model_module.resume_last_lr_step(last_global_step)
        logging.info("Successfully loaded last lr step...")
    if(args.resume_from_ckpt and args.resume_model_weights_only):
        if(os.path.isdir(args.resume_from_ckpt)):
            sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt)
        else:
            sd = torch.load(args.resume_from_ckpt)
        sd = {k[len("module."):]: v for k, v in sd.items()}
        model_module.load_state_dict(sd)
        logging.info("Successfully loaded model weights...")
 
    # TorchScript components of the model
    if(args.script_modules):
        script_preset_(model_module)

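    # Debugging alternative to the real data module: serves a single
    # pre-pickled batch (semantics assumed from the name and argument).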
    #data_module = DummyDataLoader("new_batch.pickle")
    data_module = OpenFoldDataModule(
        config=config.data, 
        batch_seed=args.seed,
        **vars(args)
    )

    data_module.prepare_data()
    data_module.setup()
    
    callbacks = []
    if(args.checkpoint_every_epoch):
        mc = ModelCheckpoint(
            every_n_epochs=1,
            auto_insert_metric_name=False,
            save_top_k=-1,
        )
        callbacks.append(mc)

    if(args.early_stopping):
        es = EarlyStoppingVerbose(
            monitor="val/lddt_ca",
            min_delta=args.min_delta,
            patience=args.patience,
            verbose=False,
            mode="max",
            check_finite=True,
            strict=True,
        )
        callbacks.append(es)

    if(args.log_performance):
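        # NOTE: assumes a per-device batch size of 1; scale this if the data
        # config specifies a larger batch size.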
        global_batch_size = args.num_nodes * args.gpus
        perf = PerformanceLoggingCallback(
            log_file=os.path.join(args.output_dir, "performance_log.json"),
            global_batch_size=global_batch_size,
        )
        callbacks.append(perf)

    if(args.log_lr):
        lr_monitor = LearningRateMonitor(logging_interval="step")
        callbacks.append(lr_monitor)

    loggers = []
    if(args.wandb):
        wdb_logger = WandbLogger(
            name=args.experiment_name,
            save_dir=args.output_dir,
            id=args.wandb_id,
            project=args.wandb_project,
            **{"entity": args.wandb_entity}
        )
        loggers.append(wdb_logger)

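    # Select the distributed strategy: DeepSpeed when a config is provided,
    # DDP for multi-GPU or multi-node runs, and single-process otherwise.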
    if(args.deepspeed_config_path is not None):
        strategy = DeepSpeedPlugin(
            config=args.deepspeed_config_path,
        )
        if(args.wandb):
            wdb_logger.experiment.save(args.deepspeed_config_path)
            wdb_logger.experiment.save("openfold/config.py")
    elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1:
        strategy = DDPPlugin(find_unused_parameters=False)
    else:
        strategy = None
 
    if(args.wandb):
        freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt"
        os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
        wdb_logger.experiment.save(f"{freeze_path}")

    trainer = pl.Trainer.from_argparse_args(
        args,
        default_root_dir=args.output_dir,
        strategy=strategy,
        callbacks=callbacks,
        logger=loggers,
    )

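    # If only the weights were restored above, start fit() fresh rather than
    # also restoring optimizer/scheduler/trainer state from the checkpoint.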
    if(args.resume_model_weights_only):
        ckpt_path = None
    else:
        ckpt_path = args.resume_from_ckpt

    trainer.fit(
        model_module, 
        datamodule=data_module,
        ckpt_path=ckpt_path,
    )


def bool_type(bool_str: str):
    bool_str_lower = bool_str.lower()
    if bool_str_lower in ('false', 'f', 'no', 'n', '0'):
        return False
    elif bool_str_lower in ('true', 't', 'yes', 'y', '1'):
        return True
    else:
        raise ValueError(f'Cannot interpret {bool_str} as bool')


if __name__ == "__main__":
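    # Example invocation (paths and flag values are illustrative):
    #   python3 train_openfold.py data/mmcif_dir/ data/alignment_dir/ \
    #       data/template_mmcif_dir/ output_dir/ 2021-10-10 \
    #       --val_data_dir=data/val_mmcif_dir/ \
    #       --val_alignment_dir=data/val_alignment_dir/ \
    #       --gpus=8 --num_nodes=1 --seed=42 \
    #       --checkpoint_every_epoch \
    #       --deepspeed_config_path=deepspeed_config.json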
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "train_data_dir", type=str,
        help="Directory containing training mmCIF files"
    )
    parser.add_argument(
        "train_alignment_dir", type=str,
        help="Directory containing precomputed training alignments"
    )
    parser.add_argument(
        "template_mmcif_dir", type=str,
        help="Directory containing mmCIF files to search for templates"
    )
    parser.add_argument(
        "output_dir", type=str,
        help='''Directory in which to output checkpoints, logs, etc. Ignored
                if not on rank 0'''
    )
    parser.add_argument(
        "max_template_date", type=str,
        help='''Cutoff for all templates. In training mode, templates are also 
                filtered by the release date of the target'''
    )
    parser.add_argument(
        "--distillation_data_dir", type=str, default=None,
        help="Directory containing training PDB files"
    )
    parser.add_argument(
        "--distillation_alignment_dir", type=str, default=None,
        help="Directory containing precomputed distillation alignments"
    )
    parser.add_argument(
        "--val_data_dir", type=str, default=None,
        help="Directory containing validation mmCIF files"
    )
    parser.add_argument(
        "--val_alignment_dir", type=str, default=None,
        help="Directory containing precomputed validation alignments"
    )
    parser.add_argument(
        "--kalign_binary_path", type=str, default='/usr/bin/kalign',
        help="Path to the kalign binary"
    )
    parser.add_argument(
        "--train_filter_path", type=str, default=None,
        help='''Optional path to a text file containing names of training
                examples to include, one per line. Used to filter the training 
                set'''
    )
    parser.add_argument(
        "--distillation_filter_path", type=str, default=None,
        help="""See --train_filter_path"""
    )
    parser.add_argument(
        "--obsolete_pdbs_file_path", type=str, default=None,
        help="""Path to obsolete.dat file containing list of obsolete PDBs and 
             their replacements."""
    )
    parser.add_argument(
        "--template_release_dates_cache_path", type=str, default=None,
        help="""Output of scripts/generate_mmcif_cache.py run on template mmCIF
                files."""
    )
    parser.add_argument(
        "--use_small_bfd", type=bool_type, default=False,
        help="Whether to use a reduced version of the BFD database"
    )
    parser.add_argument(
        "--seed", type=int, default=None,
        help="Random seed"
    )
    parser.add_argument(
        "--deepspeed_config_path", type=str, default=None,
        help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
    )
    parser.add_argument(
        "--checkpoint_every_epoch", action="store_true", default=False,
        help="""Whether to checkpoint at the end of every training epoch"""
    )
    parser.add_argument(
        "--early_stopping", type=bool_type, default=False,
        help="Whether to stop training when validation loss fails to decrease"
    )
    parser.add_argument(
        "--min_delta", type=float, default=0,
        help="""The smallest decrease in validation loss that counts as an 
                improvement for the purposes of early stopping"""
    )
    parser.add_argument(
        "--patience", type=int, default=3,
        help="Early stopping patience"
    )
    parser.add_argument(
        "--resume_from_ckpt", type=str, default=None,
        help="Path to a model checkpoint from which to restore training state"
    )
    parser.add_argument(
        "--resume_model_weights_only", type=bool_type, default=False,
        help="Whether to load just model weights as opposed to training state"
    )
    parser.add_argument(
        "--log_performance", type=bool_type, default=False,
        help="Measure performance"
    )
    parser.add_argument(
        "--wandb", action="store_true", default=False,
        help="Whether to log metrics to Weights & Biases"
    )
    parser.add_argument(
        "--experiment_name", type=str, default=None,
        help="Name of the current experiment. Used for wandb logging"
    )
    parser.add_argument(
        "--wandb_id", type=str, default=None,
        help="ID of a previous run to be resumed"
    )
    parser.add_argument(
        "--wandb_project", type=str, default=None,
        help="Name of the wandb project to which this run will belong"
    )
    parser.add_argument(
        "--wandb_entity", type=str, default=None,
        help="wandb username or team name to which runs are attributed"
    )
    parser.add_argument(
        "--script_modules", type=bool_type, default=False,
        help="Whether to TorchScript eligible components of them model"
    )
    parser.add_argument(
        "--train_chain_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--distillation_chain_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--train_epoch_len", type=int, default=10000,
        help=(
            "The virtual length of each training epoch. Stochastic filtering "
            "of training data means that training datasets have no "
            "well-defined length. This virtual length affects frequency of "
            "validation & checkpointing (by default, one of each per epoch)."
        )
    )
    parser.add_argument(
        "--log_lr", action="store_true", default=False,
        help="Whether to log the actual learning rate"
    )
    parser.add_argument(
        "--config_preset", type=str, default="initial_training",
        help=(
            'Config setting. Choose e.g. "initial_training", "finetuning", '
            '"model_1", etc. By default, the actual values in the config are '
            'used.'
        )
    )
    parser.add_argument(
        "--_distillation_structure_index_path", type=str, default=None,
    )
    parser.add_argument(
        "--alignment_index_path", type=str, default=None,
        help="Training alignment index. See the README for instructions."
    )
    parser.add_argument(
        "--distillation_alignment_index_path", type=str, default=None,
        help="Distillation alignment index. See the README for instructions."
    )
    parser = pl.Trainer.add_argparse_args(parser)
   
    # Disable the initial validation pass
    parser.set_defaults(
        num_sanity_val_steps=0,
    )

    # Remove some buggy/redundant arguments introduced by the Trainer
    remove_arguments(
        parser, 
        [
            "--accelerator", 
            "--resume_from_checkpoint",
            "--reload_dataloaders_every_epoch",
            "--reload_dataloaders_every_n_epochs",
        ]
    ) 

    args = parser.parse_args()

    if(args.seed is None and 
        ((args.gpus is not None and args.gpus > 1) or 
         (args.num_nodes is not None and args.num_nodes > 1))):
        raise ValueError("For distributed training, --seed must be specified")

    if(args.precision == "16" and args.deepspeed_config_path is not None):
        raise ValueError("DeepSpeed and FP16 training are not compatible")

    # This re-applies the training-time filters at the beginning of every epoch
    args.reload_dataloaders_every_n_epochs = 1

    main(args)