import argparse
import logging
import os
import random
import subprocess
import sys
import time

#os.environ["CUDA_VISIBLE_DEVICES"] = "0"
#os.environ["MASTER_ADDR"]="10.119.81.14"
#os.environ["MASTER_PORT"]="42069"
#os.environ["NODE_RANK"]="0"

import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
import torch

from openfold.config import model_config
from openfold.data.data_modules import (
    OpenFoldDataModule,
    DummyDataLoader,
)
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants
from openfold.utils.argparse import remove_arguments
from openfold.utils.callbacks import (
    EarlyStoppingVerbose,
)
from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.logger import PerformanceLoggingCallback
from openfold.utils.loss import AlphaFoldLoss, lddt_ca
from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
from openfold.utils.seed import seed_everything
from openfold.utils.superimposition import superimpose
from openfold.utils.tensor_utils import tensor_tree_map
from openfold.utils.validation_metrics import (
    drmsd,
    gdt_ts,
    gdt_ha,
)
from scripts.zero_to_fp32 import (
    get_fp32_state_dict_from_zero_checkpoint,
    get_global_step_from_zero_checkpoint
)


class OpenFoldWrapper(pl.LightningModule):
    """PyTorch Lightning wrapper around the OpenFold ``AlphaFold`` model.

    Bundles the model, its training loss and an exponential moving average
    (EMA) of the model weights. The EMA weights are swapped into the model for
    the duration of each validation epoch and are stored in checkpoints.
    """

    def __init__(self, config):
        super(OpenFoldWrapper, self).__init__()
        self.config = config
        self.model = AlphaFold(config)
        self.loss = AlphaFoldLoss(config.loss)
        self.ema = ExponentialMovingAverage(
            model=self.model, decay=config.ema.decay
        )

        # Clone of the live (non-EMA) weights while validation runs with the
        # EMA weights loaded; None outside of validation.
        self.cached_weights = None
        # Optimizer step from which to resume the LR schedule. -1 is PyTorch's
        # ``last_epoch`` value for a schedule that starts fresh.
        self.last_lr_step = -1

    def forward(self, batch):
        return self.model(batch)

    def _log(self, loss_breakdown, batch, outputs, train=True):
        """Log every loss term plus auxiliary structure metrics.

        Args:
            loss_breakdown: mapping of loss-term name to value, as returned by
                ``self.loss(..., _return_breakdown=True)``.
            batch: input features with the recycling dimension removed.
            outputs: model outputs for ``batch``.
            train: if True, loss terms are logged under ``train/`` per step
                (plus an ``_epoch`` aggregate); otherwise under ``val/`` per
                epoch, and superimposition-based metrics are also computed.
        """
        phase = "train" if train else "val"
        for loss_name, indiv_loss in loss_breakdown.items():
            self.log(
                f"{phase}/{loss_name}",
                indiv_loss,
                on_step=train, on_epoch=(not train), logger=True,
            )

            if(train):
                self.log(
                    f"{phase}/{loss_name}_epoch",
                    indiv_loss,
                    on_step=False, on_epoch=True, logger=True,
                )

        with torch.no_grad():
            other_metrics = self._compute_validation_metrics(
                batch,
                outputs,
                superimposition_metrics=(not train)
            )

        for k, v in other_metrics.items():
            self.log(
                f"{phase}/{k}",
                v,
                on_step=False, on_epoch=True, logger=True
            )

    def training_step(self, batch, batch_idx):
        """Run one training step and return the scalar loss."""
        # Keep the EMA shadow weights on the same device as the inputs.
        if(self.ema.device != batch["aatype"].device):
            self.ema.to(batch["aatype"].device)

        # Run the model
        outputs = self(batch)

        # Remove the recycling dimension
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

        # Compute loss
        loss, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
        )

        # Log it
        self._log(loss_breakdown, batch, outputs)

        return loss

    def on_before_zero_grad(self, *args, **kwargs):
        """Fold the freshly updated model weights into the EMA."""
        self.ema.update(self.model)

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch using the EMA weights."""
        # At the start of validation, load the EMA weights
        if(self.cached_weights is None):
            # model.state_dict() contains references to model weights rather
            # than copies. Therefore, we need to clone them before calling
            # load_state_dict().
            clone_param = lambda t: t.detach().clone()
            self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
            self.model.load_state_dict(self.ema.state_dict()["params"])

        # Run the model
        outputs = self(batch)
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

        # Compute loss and other metrics
        batch["use_clamped_fape"] = 0.
        _, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
        )

        self._log(loss_breakdown, batch, outputs, train=False)

    def validation_epoch_end(self, _):
        """Restore the live (non-EMA) weights cached by ``validation_step``."""
        # If no validation batch ran this epoch, the weights were never
        # swapped out; loading ``None`` as a state dict would crash.
        if self.cached_weights is None:
            return
        self.model.load_state_dict(self.cached_weights)
        self.cached_weights = None

    def _compute_validation_metrics(self,
        batch,
        outputs,
        superimposition_metrics=False
    ):
        """Compute structure-quality metrics for a batch.

        Always returns ``lddt_ca`` and ``drmsd_ca``; when
        ``superimposition_metrics`` is True, also ``alignment_rmsd``,
        ``gdt_ts`` and ``gdt_ha`` computed on the superimposed CA atoms.
        """
        metrics = {}

        gt_coords = batch["all_atom_positions"]
        pred_coords = outputs["final_atom_positions"]
        all_atom_mask = batch["all_atom_mask"]

        # This is super janky for superimposition. Fix later
        gt_coords_masked = gt_coords * all_atom_mask[..., None]
        pred_coords_masked = pred_coords * all_atom_mask[..., None]
        ca_pos = residue_constants.atom_order["CA"]
        gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :]
        pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :]
        all_atom_mask_ca = all_atom_mask[..., ca_pos]

        lddt_ca_score = lddt_ca(
            pred_coords,
            gt_coords,
            all_atom_mask,
            eps=self.config.globals.eps,
            per_residue=False,
        )

        metrics["lddt_ca"] = lddt_ca_score

        drmsd_ca_score = drmsd(
            pred_coords_masked_ca,
            gt_coords_masked_ca,
            mask=all_atom_mask_ca, # still required here to compute n
        )

        metrics["drmsd_ca"] = drmsd_ca_score

        if(superimposition_metrics):
            superimposed_pred, alignment_rmsd = superimpose(
                gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
            )
            gdt_ts_score = gdt_ts(
                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
            )
            gdt_ha_score = gdt_ha(
                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
            )

            metrics["alignment_rmsd"] = alignment_rmsd
            metrics["gdt_ts"] = gdt_ts_score
            metrics["gdt_ha"] = gdt_ha_score

        return metrics

    def configure_optimizers(self,
        learning_rate: float = 1e-3,
        eps: float = 1e-5,
    ) -> dict:
        """Create the Adam optimizer and the per-step AlphaFold LR schedule.

        The optimizer built here is ignored as long as a DeepSpeed optimizer
        is configured.

        Args:
            learning_rate: Adam learning rate; also recorded as each param
                group's ``initial_lr`` when the schedule is resumed.
            eps: Adam epsilon.

        Returns:
            A Lightning optimizer-config dict whose scheduler is stepped after
            every optimizer step.
        """
        optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=learning_rate,
            eps=eps
        )

        if self.last_lr_step != -1:
            # PyTorch LR schedulers constructed with last_epoch != -1 require
            # every param group to already carry an 'initial_lr'.
            for group in optimizer.param_groups:
                if 'initial_lr' not in group:
                    group['initial_lr'] = learning_rate

        # Resume the schedule from the recorded step (if any) instead of
        # restarting warmup from step 0.
        lr_scheduler = AlphaFoldLRScheduler(
            optimizer,
            last_epoch=self.last_lr_step,
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
                "name": "AlphaFoldLRScheduler",
            }
        }

    def resume_last_lr_step(self, lr_step):
        """Record the optimizer step from which to resume the LR schedule.

        Only takes effect if called before ``configure_optimizers`` runs,
        since that is where ``self.last_lr_step`` is read.
        """
        self.last_lr_step = lr_step

    def on_load_checkpoint(self, checkpoint):
        """Restore EMA state, dropping template weights if templates are off."""
        ema = checkpoint["ema"]
        if(not self.model.template_config.enabled):
            ema["params"] = {
                k: v for k, v in ema["params"].items() if "template" not in k
            }
        self.ema.load_state_dict(ema)

    def on_save_checkpoint(self, checkpoint):
        """Store the EMA state alongside the regular checkpoint contents."""
        checkpoint["ema"] = self.ema.state_dict()


def main(args):
    """Build the model, data module and Lightning trainer, then train.

    Args:
        args: parsed command-line namespace from the ``__main__`` parser
            (script options plus ``pl.Trainer`` arguments).
    """
    if(args.seed is not None):
        seed_everything(args.seed)

    config = model_config(
        args.config_preset,
        train=True,
        low_prec=(args.precision == "16")
    )

    model_module = OpenFoldWrapper(config)
    if(args.resume_from_ckpt):
        last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt)
        model_module.resume_last_lr_step(last_global_step)
        logging.info("Successfully loaded last lr step...")
    if(args.resume_from_ckpt and args.resume_model_weights_only):
        sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt)
        # Strip the "module." wrapper prefix from parameter names, but only
        # where it is actually present (blind slicing would corrupt others).
        prefix = "module."
        sd = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in sd.items()
        }
        model_module.load_state_dict(sd)
        logging.info("Successfully loaded model weights...")

    # TorchScript components of the model
    if(args.script_modules):
        script_preset_(model_module)

    #data_module = DummyDataLoader("new_batch.pickle")
    data_module = OpenFoldDataModule(
        config=config.data,
        batch_seed=args.seed,
        **vars(args)
    )

    data_module.prepare_data()
    data_module.setup()

    callbacks = []
    if(args.checkpoint_every_epoch):
        mc = ModelCheckpoint(
            every_n_epochs=1,
            auto_insert_metric_name=False,
            save_top_k=-1,
        )
        callbacks.append(mc)

    if(args.early_stopping):
        es = EarlyStoppingVerbose(
            monitor="val/lddt_ca",
            min_delta=args.min_delta,
            patience=args.patience,
            verbose=False,
            mode="max",
            check_finite=True,
            strict=True,
        )
        callbacks.append(es)

    if(args.log_performance):
        # --gpus is None (or 0) when training without GPUs; count that as a
        # single device instead of multiplying by None (TypeError).
        gpus_per_node = args.gpus or 1
        global_batch_size = args.num_nodes * gpus_per_node
        perf = PerformanceLoggingCallback(
            log_file=os.path.join(args.output_dir, "performance_log.json"),
            global_batch_size=global_batch_size,
        )
        callbacks.append(perf)

    if(args.log_lr):
        lr_monitor = LearningRateMonitor(logging_interval="step")
        callbacks.append(lr_monitor)

    loggers = []
    if(args.wandb):
        wdb_logger = WandbLogger(
            name=args.experiment_name,
            save_dir=args.output_dir,
            id=args.wandb_id,
            project=args.wandb_project,
            entity=args.wandb_entity,
        )
        loggers.append(wdb_logger)

    if(args.deepspeed_config_path is not None):
        strategy = DeepSpeedPlugin(
            config=args.deepspeed_config_path,
        )
        if(args.wandb):
            wdb_logger.experiment.save(args.deepspeed_config_path)
            wdb_logger.experiment.save("openfold/config.py")
    elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1:
        strategy = DDPPlugin(find_unused_parameters=False)
    else:
        strategy = None

    if(args.wandb):
        freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt"
        # Invoke pip without a shell so that paths containing spaces or shell
        # metacharacters are handled safely. Best-effort: a non-zero pip exit
        # status is deliberately not treated as an error.
        with open(freeze_path, "w") as freeze_file:
            subprocess.run(
                [sys.executable, "-m", "pip", "freeze"],
                stdout=freeze_file,
                check=False,
            )
        wdb_logger.experiment.save(freeze_path)

    trainer = pl.Trainer.from_argparse_args(
        args,
        default_root_dir=args.output_dir,
        strategy=strategy,
        callbacks=callbacks,
        logger=loggers,
    )

    # When only weights were restored above, do not let Lightning restore the
    # full training state from the checkpoint as well.
    if(args.resume_model_weights_only):
        ckpt_path = None
    else:
        ckpt_path = args.resume_from_ckpt

    trainer.fit(
        model_module,
        datamodule=data_module,
        ckpt_path=ckpt_path,
    )


def bool_type(bool_str: str):
    """Convert a command-line string to a bool (for argparse ``type=``).

    Recognized spellings, compared case-insensitively, are
    false/f/no/n/0 and true/t/yes/y/1.

    Raises:
        ValueError: if ``bool_str`` is none of the recognized spellings.
    """
    falsy = {'false', 'f', 'no', 'n', '0'}
    truthy = {'true', 't', 'yes', 'y', '1'}
    normalized = bool_str.lower()
    if normalized in falsy:
        return False
    if normalized in truthy:
        return True
    raise ValueError(f'Cannot interpret {bool_str} as bool')


if __name__ == "__main__":
    # Command-line interface: positional data/output paths, script options,
    # then every pl.Trainer option appended by add_argparse_args.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "train_data_dir", type=str,
        help="Directory containing training mmCIF files"
    )
    parser.add_argument(
        "train_alignment_dir", type=str,
        help="Directory containing precomputed training alignments"
    )
    parser.add_argument(
        "template_mmcif_dir", type=str,
        help="Directory containing mmCIF files to search for templates"
    )
    parser.add_argument(
        "output_dir", type=str,
        help='''Directory in which to output checkpoints, logs, etc. Ignored
                if not on rank 0'''
    )
    parser.add_argument(
        "max_template_date", type=str,
        help='''Cutoff for all templates. In training mode, templates are also 
                filtered by the release date of the target'''
    )
    parser.add_argument(
        "--distillation_data_dir", type=str, default=None,
        help="Directory containing training PDB files"
    )
    parser.add_argument(
        "--distillation_alignment_dir", type=str, default=None,
        help="Directory containing precomputed distillation alignments"
    )
    parser.add_argument(
        "--val_data_dir", type=str, default=None,
        help="Directory containing validation mmCIF files"
    )
    parser.add_argument(
        "--val_alignment_dir", type=str, default=None,
        help="Directory containing precomputed validation alignments"
    )
    parser.add_argument(
        "--kalign_binary_path", type=str, default='/usr/bin/kalign',
        help="Path to the kalign binary"
    )
    parser.add_argument(
        "--train_filter_path", type=str, default=None,
        help='''Optional path to a text file containing names of training
                examples to include, one per line. Used to filter the training 
                set'''
    )
    parser.add_argument(
        "--distillation_filter_path", type=str, default=None,
        help="""See --train_filter_path"""
    )
    parser.add_argument(
        "--obsolete_pdbs_file_path", type=str, default=None,
        help="""Path to obsolete.dat file containing list of obsolete PDBs and 
             their replacements."""
    )
    parser.add_argument(
        "--template_release_dates_cache_path", type=str, default=None,
        help="""Output of scripts/generate_mmcif_cache.py run on template mmCIF
                files."""
    )
    parser.add_argument(
        "--use_small_bfd", type=bool_type, default=False,
        help="Whether to use a reduced version of the BFD database"
    )
    parser.add_argument(
        "--seed", type=int, default=None,
        help="Random seed"
    )
    parser.add_argument(
        "--deepspeed_config_path", type=str, default=None,
        help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
    )
    parser.add_argument(
        "--checkpoint_every_epoch", action="store_true", default=False,
        help="""Whether to checkpoint at the end of every training epoch"""
    )
    # Early stopping monitors val/lddt_ca in "max" mode (see main), so the
    # help texts below describe an lDDT-Ca that must increase.
    parser.add_argument(
        "--early_stopping", type=bool_type, default=False,
        help="Whether to stop training when validation lDDT-Ca fails to improve"
    )
    parser.add_argument(
        "--min_delta", type=float, default=0,
        help="""The smallest increase in validation lDDT-Ca that counts as an 
                improvement for the purposes of early stopping"""
    )
    parser.add_argument(
        "--patience", type=int, default=3,
        help="Early stopping patience"
    )
    parser.add_argument(
        "--resume_from_ckpt", type=str, default=None,
        help="Path to a model checkpoint from which to restore training state"
    )
    parser.add_argument(
        "--resume_model_weights_only", type=bool_type, default=False,
        help="Whether to load just model weights as opposed to training state"
    )
    parser.add_argument(
        "--log_performance", type=bool_type, default=False,
        help="Measure performance"
    )
    parser.add_argument(
        "--wandb", action="store_true", default=False,
        help="Whether to log metrics to Weights & Biases"
    )
    parser.add_argument(
        "--experiment_name", type=str, default=None,
        help="Name of the current experiment. Used for wandb logging"
    )
    parser.add_argument(
        "--wandb_id", type=str, default=None,
        help="ID of a previous run to be resumed"
    )
    parser.add_argument(
        "--wandb_project", type=str, default=None,
        help="Name of the wandb project to which this run will belong"
    )
    parser.add_argument(
        "--wandb_entity", type=str, default=None,
        help="wandb username or team name to which runs are attributed"
    )
    parser.add_argument(
        "--script_modules", type=bool_type, default=False,
        help="Whether to TorchScript eligible components of the model"
    )
    parser.add_argument(
        "--train_chain_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--distillation_chain_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--train_epoch_len", type=int, default=10000,
        help=(
            "The virtual length of each training epoch. Stochastic filtering "
            "of training data means that training datasets have no "
            "well-defined length. This virtual length affects frequency of "
            "validation & checkpointing (by default, one of each per epoch)."
        )
    )
    parser.add_argument(
        "--log_lr", action="store_true", default=False,
        help="Whether to log the actual learning rate"
    )
    parser.add_argument(
        "--config_preset", type=str, default="initial_training",
        help=(
            'Config setting. Choose e.g. "initial_training", "finetuning", '
            '"model_1", etc. By default, the actual values in the config are '
            'used.'
        )
    )
    parser.add_argument(
        "--_distillation_structure_index_path", type=str, default=None,
    )
    parser.add_argument(
        "--alignment_index_path", type=str, default=None,
        help="Training alignment index. See the README for instructions."
    )
    parser.add_argument(
        "--distillation_alignment_index_path", type=str, default=None,
        help="Distillation alignment index. See the README for instructions."
    )
    parser = pl.Trainer.add_argparse_args(parser)

    # Disable the initial validation pass
    parser.set_defaults(
        num_sanity_val_steps=0,
    )

    # Remove some buggy/redundant arguments introduced by the Trainer
    remove_arguments(
        parser,
        [
            "--accelerator",
            "--resume_from_checkpoint",
            "--reload_dataloaders_every_epoch",
            "--reload_dataloaders_every_n_epochs",
        ]
    )

    args = parser.parse_args()

    if(args.seed is None and
        ((args.gpus is not None and args.gpus > 1) or
         (args.num_nodes is not None and args.num_nodes > 1))):
        raise ValueError("For distributed training, --seed must be specified")

    # This re-applies the training-time filters at the beginning of every epoch
    args.reload_dataloaders_every_n_epochs = 1

    main(args)