import argparse
import logging
import os
import sys
import json

import pytorch_lightning as pl
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPStrategy, DeepSpeedStrategy
from pytorch_lightning.plugins.environments import MPIEnvironment
from pytorch_lightning import seed_everything
import torch
import wandb

from openfold.config import model_config
from openfold.data.data_modules import OpenFoldDataModule, OpenFoldMultimerDataModule
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants
from openfold.utils.argparse_utils import remove_arguments
from openfold.utils.callbacks import (
    EarlyStoppingVerbose,
)
from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.loss import AlphaFoldLoss, lddt_ca
from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
from openfold.utils.multi_chain_permutation import multi_chain_permutation_align
from openfold.utils.superimposition import superimpose
from openfold.utils.tensor_utils import tensor_tree_map
from openfold.utils.validation_metrics import (
    drmsd,
    gdt_ts,
    gdt_ha,
)
from openfold.utils.import_weights import (
    import_jax_weights_,
    import_openfold_weights_
)
from scripts.zero_to_fp32 import (
    get_fp32_state_dict_from_zero_checkpoint,
    get_global_step_from_zero_checkpoint
)

from openfold.utils.logger import PerformanceLoggingCallback


class OpenFoldWrapper(pl.LightningModule):
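    """A thin PyTorch Lightning wrapper around the AlphaFold model.

    Owns the training and validation loops, the AlphaFold loss, an
    exponential moving average (EMA) of the model weights, and the
    optimizer/LR-scheduler configuration.
    """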
    def __init__(self, config):
        super(OpenFoldWrapper, self).__init__()
        self.config = config
        self.model = AlphaFold(config)
        self.is_multimer = self.config.globals.is_multimer

        self.loss = AlphaFoldLoss(config.loss)

        self.ema = ExponentialMovingAverage(
            model=self.model, decay=config.ema.decay
        )
        self.cached_weights = None
        self.last_lr_step = -1
        self.save_hyperparameters()

    def forward(self, batch):
        return self.model(batch)

    def _log(self, loss_breakdown, batch, outputs, train=True):
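        """Log a loss breakdown and validation metrics under train/ or val/."""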
        phase = "train" if train else "val"
        for loss_name, indiv_loss in loss_breakdown.items():
            self.log(
                f"{phase}/{loss_name}", 
                indiv_loss, 
                prog_bar=(loss_name == 'loss'),
                on_step=train, on_epoch=(not train), logger=True, sync_dist=False,
            )

            if(train):
                self.log(
                    f"{phase}/{loss_name}_epoch",
                    indiv_loss,
                    on_step=False, on_epoch=True, logger=True, sync_dist=False,
                )

        with torch.no_grad():
            other_metrics = self._compute_validation_metrics(
                batch, 
                outputs,
                superimposition_metrics=(not train)
            )

        for k,v in other_metrics.items():
            self.log(
                f"{phase}/{k}",
                torch.mean(v),
                prog_bar=(k == 'loss'),
                on_step=False, on_epoch=True, logger=True, sync_dist=False,
            )

    def training_step(self, batch, batch_idx):
        if(self.ema.device != batch["aatype"].device):
            self.ema.to(batch["aatype"].device)

        ground_truth = batch.pop('gt_features', None)

        # Run the model
        outputs = self(batch)

        # Remove the recycling dimension
        batch = tensor_tree_map(lambda t: t[..., -1], batch)
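        # For multimers, the ground-truth chain order is ambiguous; the
        # permutation alignment below reorders chains to best match the
        # prediction before the loss is computed.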

        if self.is_multimer:
            batch = multi_chain_permutation_align(out=outputs,
                                                  features=batch,
                                                  ground_truth=ground_truth)

        # Compute loss
        loss, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
        )

        # Log it
        self._log(loss_breakdown, batch, outputs)

        return loss

    def on_before_zero_grad(self, *args, **kwargs):
        self.ema.update(self.model)

    def validation_step(self, batch, batch_idx):
        # At the start of validation, load the EMA weights
        if(self.cached_weights is None):
            # model.state_dict() contains references to model weights rather
            # than copies. Therefore, we need to clone them before calling 
            # load_state_dict().
            clone_param = lambda t: t.detach().clone()
            self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
            self.model.load_state_dict(self.ema.state_dict()["params"])

        ground_truth = batch.pop('gt_features', None)

        # Run the model
        outputs = self(batch)
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

        batch["use_clamped_fape"] = 0.

        if self.is_multimer:
            batch = multi_chain_permutation_align(out=outputs,
                                                  features=batch,
                                                  ground_truth=ground_truth)

        # Compute loss and other metrics
        _, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
        )

        self._log(loss_breakdown, batch, outputs, train=False)
        
    def on_validation_epoch_end(self):
        # Restore the model weights to normal
        self.model.load_state_dict(self.cached_weights)
        self.cached_weights = None

    def _compute_validation_metrics(self, 
        batch, 
        outputs, 
        superimposition_metrics=False
    ):
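        """Compute structural metrics (lDDT-Ca, dRMSD and, optionally,
        superimposition-based GDT-TS, GDT-HA and RMSD) for the predicted
        coordinates."""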
        metrics = {}
        
        gt_coords = batch["all_atom_positions"]
        pred_coords = outputs["final_atom_positions"]
        all_atom_mask = batch["all_atom_mask"]
    
        # This is super janky for superimposition. Fix later
        gt_coords_masked = gt_coords * all_atom_mask[..., None]
        pred_coords_masked = pred_coords * all_atom_mask[..., None]
        ca_pos = residue_constants.atom_order["CA"]
        gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :]
        pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :]
        all_atom_mask_ca = all_atom_mask[..., ca_pos]
    
        lddt_ca_score = lddt_ca(
            pred_coords,
            gt_coords,
            all_atom_mask,
            eps=self.config.globals.eps,
            per_residue=False,
        )
   
        metrics["lddt_ca"] = lddt_ca_score
   
        drmsd_ca_score = drmsd(
            pred_coords_masked_ca,
            gt_coords_masked_ca,
            mask=all_atom_mask_ca, # still required here to compute n
        )
   
        metrics["drmsd_ca"] = drmsd_ca_score
    
        if(superimposition_metrics):
            superimposed_pred, alignment_rmsd = superimpose(
                gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
            )
            gdt_ts_score = gdt_ts(
                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
            )
            gdt_ha_score = gdt_ha(
                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
            )

            metrics["alignment_rmsd"] = alignment_rmsd
            metrics["gdt_ts"] = gdt_ts_score
            metrics["gdt_ha"] = gdt_ha_score
    
        return metrics

    def configure_optimizers(self,
        learning_rate: float = 1e-3,
        eps: float = 1e-5,
    ) -> dict:
        # Ignored as long as a DeepSpeed optimizer is configured
        optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=learning_rate,
            eps=eps
        )

        # LR schedulers constructed with last_epoch != -1 expect an
        # "initial_lr" entry in each param group; backfill it when resuming
        if self.last_lr_step != -1:
            for group in optimizer.param_groups:
                if 'initial_lr' not in group:
                    group['initial_lr'] = learning_rate

        lr_scheduler = AlphaFoldLRScheduler(
            optimizer,
            last_epoch=self.last_lr_step
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
                "name": "AlphaFoldLRScheduler",
            }
        }

    def on_load_checkpoint(self, checkpoint):
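        # If templates are disabled in this config, the model carries no
        # template parameters, so drop them from the stored EMA state to
        # avoid unexpected keys in load_state_dict()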
        ema = checkpoint["ema"]
        if(not self.model.template_config.enabled):
            ema["params"] = {k: v for k, v in ema["params"].items() if "template" not in k}
        self.ema.load_state_dict(ema)

    def on_save_checkpoint(self, checkpoint):
        checkpoint["ema"] = self.ema.state_dict()

    def resume_last_lr_step(self, lr_step):
        self.last_lr_step = lr_step

    def load_from_jax(self, jax_path):
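        # Infer the model version from the filename, e.g.
        # "params_model_1_ptm.npz" -> "model_1_ptm"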
        model_basename = os.path.splitext(
                os.path.basename(
                    os.path.normpath(jax_path)
                )
        )[0]
        model_version = "_".join(model_basename.split("_")[1:])
        import_jax_weights_(
                self.model, jax_path, version=model_version
        )


def main(args):
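    """Assemble the config, model, data module, callbacks and Trainer, then run training."""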
    if(args.seed is not None):
        seed_everything(args.seed) 

    config = model_config(
        args.config_preset, 
        train=True, 
        low_prec=(str(args.precision) == "16")
    ) 
    if args.experiment_config_json: 
        with open(args.experiment_config_json, 'r') as f:
            custom_config_dict = json.load(f)
        config.update_from_flattened_dict(custom_config_dict)
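        # Keys are dot-separated paths into the nested config, e.g.
        # {"globals.chunk_size": 4} (illustrative key/value)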

    model_module = OpenFoldWrapper(config)

    if args.resume_from_ckpt:
        if args.resume_model_weights_only:
            # Load the checkpoint
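            # DeepSpeed ZeRO checkpoints are directories; Lightning
            # checkpoints are single files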
            if os.path.isdir(args.resume_from_ckpt):
                sd = get_fp32_state_dict_from_zero_checkpoint(
                    args.resume_from_ckpt)
            else:
                sd = torch.load(args.resume_from_ckpt)
            # Process the state dict
            if 'module' in sd:
                sd = {k[len('module.'):]: v for k, v in sd['module'].items()}
                import_openfold_weights_(model=model_module, state_dict=sd)
            elif 'state_dict' in sd:
                import_openfold_weights_(
                    model=model_module, state_dict=sd['state_dict'])
            else:
                # Loading from pre-trained model
                sd = {'model.'+k: v for k, v in sd.items()}
                import_openfold_weights_(model=model_module, state_dict=sd)
            logging.info("Successfully loaded model weights...")

        else:  # Loads a checkpoint to start from a specific time step
            if os.path.isdir(args.resume_from_ckpt):
                last_global_step = get_global_step_from_zero_checkpoint(
                    args.resume_from_ckpt)
            else:
                sd = torch.load(args.resume_from_ckpt)
                last_global_step = int(sd['global_step'])
            model_module.resume_last_lr_step(last_global_step)
            logging.info("Successfully loaded last lr step...")

    if args.resume_from_jax_params:
        model_module.load_from_jax(args.resume_from_jax_params)
        logging.info(f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")
 
    # TorchScript components of the model
    if(args.script_modules):
        script_preset_(model_module)

    if "multimer" in args.config_preset:
        data_module = OpenFoldMultimerDataModule(
            config=config.data,
            batch_seed=args.seed,
            **vars(args)
        )
    else:
        data_module = OpenFoldDataModule(
            config=config.data, 
            batch_seed=args.seed,
            **vars(args)
        )

    data_module.prepare_data()
    data_module.setup()
    
    callbacks = []
    if(args.checkpoint_every_epoch):
        mc = ModelCheckpoint(
            every_n_epochs=1,
            auto_insert_metric_name=False,
            save_top_k=-1,
        )
        callbacks.append(mc)

    if(args.early_stopping):
        es = EarlyStoppingVerbose(
            monitor="val/lddt_ca",
            min_delta=args.min_delta,
            patience=args.patience,
            verbose=False,
            mode="max",
            check_finite=True,
            strict=True,
        )
        callbacks.append(es)

    if(args.log_performance):
        global_batch_size = args.num_nodes * args.gpus
        perf = PerformanceLoggingCallback(
            log_file=os.path.join(args.output_dir, "performance_log.json"),
            global_batch_size=global_batch_size,
        )
        callbacks.append(perf)

    if(args.log_lr):
        lr_monitor = LearningRateMonitor(logging_interval="step")
        callbacks.append(lr_monitor)

    loggers = []
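    # PMI_RANK is set by MPICH-style MPI launchers; fall back to rank 0 when
    # it is absent so that non-MPI runs still work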
    is_rank_zero = int(os.environ.get("PMI_RANK", 0)) == 0
    if(args.wandb):
        if args.mpi_plugin and is_rank_zero:
            wandb_init_dict = dict(
                name=args.experiment_name,
                project=args.wandb_project,
                id=args.wandb_id,
                dir=args.output_dir,
                resume="allow",
                anonymous=None,
                entity=args.wandb_entity
            )
            wandb.run = wandb.init(**wandb_init_dict)

        wdb_logger = WandbLogger(
            name=args.experiment_name,
            save_dir=args.output_dir,
            id=args.wandb_id,
            project=args.wandb_project,
            entity=args.wandb_entity
        )
        loggers.append(wdb_logger)

    cluster_environment = MPIEnvironment() if args.mpi_plugin else None
    if(args.deepspeed_config_path is not None):
        strategy = DeepSpeedStrategy(
            config=args.deepspeed_config_path,
            cluster_environment=cluster_environment,
        )
        if(args.wandb and is_rank_zero):
            wdb_logger.experiment.save(args.deepspeed_config_path)
            wdb_logger.experiment.save("openfold/config.py")
    elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1:
        strategy = DDPStrategy(find_unused_parameters=False,
                               cluster_environment=cluster_environment)
    else:
        strategy = None
 
    if(args.wandb and is_rank_zero):
        freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt"
        os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
        wdb_logger.experiment.save(f"{freeze_path}")

    trainer = pl.Trainer(
        num_nodes=args.num_nodes,
        devices=args.gpus,
        precision=args.precision,
        max_epochs=args.max_epochs,
        default_root_dir=args.output_dir,
        strategy=strategy,
        callbacks=callbacks,
        logger=loggers,
        profiler='simple',
    )

    if(args.resume_model_weights_only):
        ckpt_path = None
    else:
        ckpt_path = args.resume_from_ckpt

    trainer.fit(
        model_module, 
        datamodule=data_module,
        ckpt_path=ckpt_path,
    )


def bool_type(bool_str: str):
    bool_str_lower = bool_str.lower()
    if bool_str_lower in ('false', 'f', 'no', 'n', '0'):
        return False
    elif bool_str_lower in ('true', 't', 'yes', 'y', '1'):
        return True
    else:
        raise ValueError(f'Cannot interpret {bool_str} as bool')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "train_data_dir", type=str,
        help="Directory containing training mmCIF files"
    )
    parser.add_argument(
        "train_alignment_dir", type=str,
        help="Directory containing precomputed training alignments"
    )
    parser.add_argument(
        "template_mmcif_dir", type=str,
        help="Directory containing mmCIF files to search for templates"
    )
    parser.add_argument(
        "output_dir", type=str,
        help='''Directory in which to output checkpoints, logs, etc. Ignored
                if not on rank 0'''
    )
    parser.add_argument(
        "max_template_date", type=str,
        help='''Cutoff for all templates. In training mode, templates are also
                filtered by the release date of the target'''
    )
    parser.add_argument(
        "--train_mmcif_data_cache_path", type=str, default=None,
        help="Path to a JSON file with metadata for the mmCIF structures used during training"
    )
    parser.add_argument(
        "--use_single_seq_mode", type=bool_type, default=False,
        help="Use single sequence embeddings instead of MSAs."
    )
    parser.add_argument(
        "--distillation_data_dir", type=str, default=None,
        help="Directory containing training PDB files"
    )
    parser.add_argument(
        "--distillation_alignment_dir", type=str, default=None,
        help="Directory containing precomputed distillation alignments"
    )
    parser.add_argument(
        "--val_data_dir", type=str, default=None,
        help="Directory containing validation mmCIF files"
    )
    parser.add_argument(
        "--val_alignment_dir", type=str, default=None,
        help="Directory containing precomputed validation alignments"
    )
    parser.add_argument(
        "--val_mmcif_data_cache_path", type=str, default=None,
        help="Path to a JSON file with metadata for the mmCIF structures used during validation"
    )
    parser.add_argument(
        "--kalign_binary_path", type=str, default='/usr/bin/kalign',
        help="Path to the kalign binary"
    )
    parser.add_argument(
        "--train_filter_path", type=str, default=None,
        help='''Optional path to a text file containing names of training
                examples to include, one per line. Used to filter the training
                set'''
    )
    parser.add_argument(
        "--distillation_filter_path", type=str, default=None,
        help="""See --train_filter_path"""
    )
    parser.add_argument(
        "--obsolete_pdbs_file_path", type=str, default=None,
        help="""Path to obsolete.dat file containing list of obsolete PDBs and
             their replacements."""
    )
    parser.add_argument(
        "--template_release_dates_cache_path", type=str, default=None,
        help="""Output of scripts/generate_mmcif_cache.py run on template mmCIF
                files."""
    )
    parser.add_argument(
        "--use_small_bfd", type=bool_type, default=False,
        help="Whether to use a reduced version of the BFD database"
    )
    parser.add_argument(
        "--seed", type=int, default=None,
        help="Random seed"
    )
    parser.add_argument(
        "--deepspeed_config_path", type=str, default=None,
        help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
    )
    parser.add_argument(
        "--checkpoint_every_epoch", action="store_true", default=False,
        help="""Whether to checkpoint at the end of every training epoch"""
    )
    parser.add_argument(
        "--early_stopping", type=bool_type, default=False,
        help="Whether to stop training when validation loss fails to decrease"
    )
    parser.add_argument(
        "--min_delta", type=float, default=0,
        help="""The smallest decrease in validation loss that counts as an
                improvement for the purposes of early stopping"""
    )
    parser.add_argument(
        "--patience", type=int, default=3,
        help="Early stopping patience"
    )
    parser.add_argument(
        "--resume_from_ckpt", type=str, default=None,
        help="Path to a model checkpoint from which to restore training state"
    )
    parser.add_argument(
        "--resume_model_weights_only", type=bool_type, default=False,
        help="Whether to load just model weights as opposed to training state"
    )
    parser.add_argument(
        "--resume_from_jax_params", type=str, default=None,
        help="""Path to an .npz JAX parameter file with which to initialize the model"""
    )
    parser.add_argument(
        "--log_performance", type=bool_type, default=False,
        help="Measure performance"
    )
    parser.add_argument(
        "--wandb", action="store_true", default=False,
        help="Whether to log metrics to Weights & Biases"
    )
    parser.add_argument(
        "--experiment_name", type=str, default=None,
        help="Name of the current experiment. Used for wandb logging"
    )
    parser.add_argument(
        "--wandb_id", type=str, default=None,
        help="ID of a previous run to be resumed"
    )
    parser.add_argument(
        "--wandb_project", type=str, default=None,
        help="Name of the wandb project to which this run will belong"
    )
    parser.add_argument(
        "--wandb_entity", type=str, default=None,
        help="wandb username or team name to which runs are attributed"
    )
    parser.add_argument(
        "--script_modules", type=bool_type, default=False,
        help="Whether to TorchScript-compile eligible components of the model"
    )
    parser.add_argument(
        "--train_chain_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--distillation_chain_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--train_epoch_len", type=int, default=10000,
        help=(
            "The virtual length of each training epoch. Stochastic filtering "
            "of training data means that training datasets have no "
            "well-defined length. This virtual length affects frequency of "
            "validation & checkpointing (by default, one of each per epoch)."
        )
    )
    parser.add_argument(
        "--log_lr", action="store_true", default=False,
        help="Whether to log the actual learning rate"
    )
    parser.add_argument(
        "--config_preset", type=str, default="initial_training",
        help=(
            'Config setting. Choose e.g. "initial_training", "finetuning", '
            '"model_1", etc. By default, the actual values in the config are '
            'used.'
        )
    )
    parser.add_argument(
        "--_distillation_structure_index_path", type=str, default=None,
    )
    parser.add_argument(
        "--alignment_index_path", type=str, default=None,
        help="Training alignment index. See the README for instructions."
    )
    parser.add_argument(
        "--distillation_alignment_index_path", type=str, default=None,
        help="Distillation alignment index. See the README for instructions."
    )
    parser.add_argument(
        "--experiment_config_json", default="",
        help="Path to a JSON file with custom config values that override the config preset",
    )
    parser.add_argument("--num_nodes", type=int, default=1)
    parser.add_argument("--gpus", type=int, default=None)
    parser.add_argument("--max_epochs", type=int, default=None)
    parser.add_argument("--precision", type=str, default="32")
    parser.add_argument("--log_every_n_steps", type=int, default=50)
    parser.add_argument("--accumulate_grad_batches", type=int, default=1)
    parser.add_argument("--flush_logs_every_n_steps", type=int, default=5)
    parser.add_argument("--num_sanity_val_steps", type=int, default=0)
    parser.add_argument("--mpi_plugin", action="store_true", default=False)
    # parser = pl.Trainer.add_argparse_args(parser)
   
    # Disable the initial validation pass
    parser.set_defaults(
        num_sanity_val_steps=0,
    )

    # Remove some buggy/redundant arguments introduced by the Trainer
    remove_arguments(
        parser, 
        [
            "--accelerator", 
            "--resume_from_checkpoint",
            "--reload_dataloaders_every_epoch",
            "--reload_dataloaders_every_n_epochs",
        ]
    ) 

    args = parser.parse_args()

    if(args.seed is None and 
        ((args.gpus is not None and args.gpus > 1) or 
         (args.num_nodes is not None and args.num_nodes > 1))):
        raise ValueError("For distributed training, --seed must be specified")

    if(str(args.precision) == "16" and args.deepspeed_config_path is not None):
        raise ValueError("DeepSpeed and FP16 training are not compatible")

    if(args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
        raise ValueError("--resume_from_jax_params and --resume_from_ckpt are mutually exclusive")

    # This re-applies the training-time filters at the beginning of every epoch
    args.reload_dataloaders_every_n_epochs = 1

    main(args)
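
# Example invocation (illustrative paths; flags as defined above):
#   python train_openfold.py /data/pdb_mmcif /data/alignments /data/pdb_mmcif \
#       /data/output 2021-09-30 --config_preset initial_training \
#       --seed 42 --gpus 4 --num_nodes 1 --checkpoint_every_epoch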