train_openfold.py
import argparse
import logging
import os
import sys
import json

import pytorch_lightning as pl
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks import DeviceStatsMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPStrategy, DeepSpeedStrategy
from pytorch_lightning.plugins.environments import MPIEnvironment
from pytorch_lightning import seed_everything
import torch
import wandb

from openfold.config import model_config
from openfold.data.data_modules import OpenFoldDataModule, OpenFoldMultimerDataModule
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants
from openfold.utils.argparse_utils import remove_arguments
from openfold.utils.callbacks import (
    EarlyStoppingVerbose,
)
from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.loss import AlphaFoldLoss, lddt_ca
from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
from openfold.utils.multi_chain_permutation import multi_chain_permutation_align
from openfold.utils.superimposition import superimpose
from openfold.utils.tensor_utils import tensor_tree_map
from openfold.utils.validation_metrics import (
    drmsd,
    gdt_ts,
    gdt_ha,
)
from openfold.utils.import_weights import (
    import_jax_weights_,
    import_openfold_weights_
)
from scripts.zero_to_fp32 import (
    get_fp32_state_dict_from_zero_checkpoint,
    get_global_step_from_zero_checkpoint
)

from openfold.utils.logger import PerformanceLoggingCallback

class OpenFoldWrapper(pl.LightningModule):
    def __init__(self, config):
        super(OpenFoldWrapper, self).__init__()
        self.config = config
        self.model = AlphaFold(config)
        self.is_multimer = self.config.globals.is_multimer

        self.loss = AlphaFoldLoss(config.loss)

        # Exponential moving average of the model weights; swapped in for
        # evaluation in validation_step
        self.ema = ExponentialMovingAverage(
            model=self.model, decay=config.ema.decay
        )

        self.cached_weights = None
        self.last_lr_step = -1
        self.save_hyperparameters()

    def forward(self, batch):
        return self.model(batch)

    def _log(self, loss_breakdown, batch, outputs, train=True):
        phase = "train" if train else "val"
        for loss_name, indiv_loss in loss_breakdown.items():
            self.log(
                f"{phase}/{loss_name}", 
                indiv_loss, 
                prog_bar=(loss_name == 'loss'),
                on_step=train, on_epoch=(not train), logger=True, sync_dist=True,
            )

            if(train):
                self.log(
                    f"{phase}/{loss_name}_epoch",
                    indiv_loss,
                    on_step=False, on_epoch=True, logger=True, sync_dist=True,
                )

        with torch.no_grad():
            other_metrics = self._compute_validation_metrics(
                batch,
                outputs,
                superimposition_metrics=(not train)
            )

        for k, v in other_metrics.items():
            self.log(
                f"{phase}/{k}",
                torch.mean(v),
                prog_bar=(k == 'loss'),
                on_step=False, on_epoch=True, logger=True, sync_dist=True,
            )

    def training_step(self, batch, batch_idx):
        if(self.ema.device != batch["aatype"].device):
            self.ema.to(batch["aatype"].device)

        ground_truth = batch.pop('gt_features', None)

        # Run the model
        outputs = self(batch)

        # Remove the recycling dimension
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

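        # For multimer targets the mapping between predicted and ground-truth
        # chains is ambiguous up to permutations of identical chains; align
        # the ground truth to the prediction before computing the loss.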
        if self.is_multimer:
            batch = multi_chain_permutation_align(out=outputs,
                                                  features=batch,
                                                  ground_truth=ground_truth)

        # Compute loss
        loss, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
        )

        # Log it
        self._log(loss_breakdown, batch, outputs)

        return loss

    def on_before_zero_grad(self, *args, **kwargs):
        # Refresh the EMA copy of the weights once per optimizer step
        self.ema.update(self.model)

    def validation_step(self, batch, batch_idx):
        # At the start of validation, load the EMA weights
        if(self.cached_weights is None):
            # model.state_dict() contains references to model weights rather
            # than copies. Therefore, we need to clone them before calling
            # load_state_dict().
            clone_param = lambda t: t.detach().clone()
            self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
            self.model.load_state_dict(self.ema.state_dict()["params"])

        ground_truth = batch.pop('gt_features', None)

        # Run the model
        outputs = self(batch)
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

        # Validation scores the unclamped FAPE loss
        batch["use_clamped_fape"] = 0.

        if self.is_multimer:
            batch = multi_chain_permutation_align(out=outputs,
                                                  features=batch,
                                                  ground_truth=ground_truth)

        # Compute loss and other metrics
        _, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
        )

        self._log(loss_breakdown, batch, outputs, train=False)

    def on_validation_epoch_end(self):
        # Restore the model weights to normal
        self.model.load_state_dict(self.cached_weights)
        self.cached_weights = None

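    # Validation metrics: lDDT-Ca (superposition-free, local accuracy), dRMSD
    # on Ca atoms, and optionally superposition-based scores (alignment RMSD,
    # GDT_TS, GDT_HA).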
    def _compute_validation_metrics(self, 
        batch, 
        outputs, 
        superimposition_metrics=False
    ):
        metrics = {}
        
        gt_coords = batch["all_atom_positions"]
        pred_coords = outputs["final_atom_positions"]
        all_atom_mask = batch["all_atom_mask"]
    
        # This is super janky for superimposition. Fix later
        gt_coords_masked = gt_coords * all_atom_mask[..., None]
        pred_coords_masked = pred_coords * all_atom_mask[..., None]
        ca_pos = residue_constants.atom_order["CA"]
        gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :]
        pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :]
        all_atom_mask_ca = all_atom_mask[..., ca_pos]
    
        lddt_ca_score = lddt_ca(
            pred_coords,
            gt_coords,
            all_atom_mask,
            eps=self.config.globals.eps,
            per_residue=False,
        )
   
        metrics["lddt_ca"] = lddt_ca_score
   
        drmsd_ca_score = drmsd(
            pred_coords_masked_ca,
            gt_coords_masked_ca,
            mask=all_atom_mask_ca, # still required here to compute n
        )
   
        metrics["drmsd_ca"] = drmsd_ca_score
    
        if(superimposition_metrics):
            superimposed_pred, alignment_rmsd = superimpose(
                gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
            )
            gdt_ts_score = gdt_ts(
                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
            )
            gdt_ha_score = gdt_ha(
                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
            )

            metrics["alignment_rmsd"] = alignment_rmsd
            metrics["gdt_ts"] = gdt_ts_score
            metrics["gdt_ha"] = gdt_ha_score
    
        return metrics

    def configure_optimizers(self, 
        learning_rate: float = 1e-3,
        eps: float = 1e-5,
    ) -> torch.optim.Adam:
        # Ignored as long as a DeepSpeed optimizer is configured
        optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=learning_rate,
            eps=eps
        )

        if self.last_lr_step != -1:
            for group in optimizer.param_groups:
                if 'initial_lr' not in group:
                    group['initial_lr'] = learning_rate

        # AlphaFold-style schedule (warmup followed by decay); stepped once
        # per optimizer step rather than per epoch
        lr_scheduler = AlphaFoldLRScheduler(
            optimizer,
            last_epoch=self.last_lr_step
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
                "name": "AlphaFoldLRScheduler",
            }
        }

    def on_load_checkpoint(self, checkpoint):
        ema = checkpoint["ema"]
        if(not self.model.template_config.enabled):
            # Drop template-stack EMA params when the template pipeline is off
            ema["params"] = {k: v for k, v in ema["params"].items() if "template" not in k}
        self.ema.load_state_dict(ema)

    def on_save_checkpoint(self, checkpoint):
        checkpoint["ema"] = self.ema.state_dict()

    def resume_last_lr_step(self, lr_step):
        self.last_lr_step = lr_step

    def load_from_jax(self, jax_path):
        model_basename = os.path.splitext(
                os.path.basename(
                    os.path.normpath(jax_path)
                )
        )[0]
        model_version = "_".join(model_basename.split("_")[1:])
        import_jax_weights_(
                self.model, jax_path, version=model_version
        )

def main(args):
    if(args.seed is not None):
        seed_everything(args.seed) 

    config = model_config(
        args.config_preset,
        train=True,
        low_prec=(str(args.precision) == "16")
    )

    if args.experiment_config_json:
        with open(args.experiment_config_json, 'r') as f:
            custom_config_dict = json.load(f)
        config.update_from_flattened_dict(custom_config_dict)

    model_module = OpenFoldWrapper(config)

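    # Note: a DeepSpeed ZeRO checkpoint is a directory of per-rank shards; the
    # zero_to_fp32 helpers consolidate it into a single fp32 state dict or
    # read the global step from it.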
    if args.resume_from_ckpt:
        if args.resume_model_weights_only:
            # Load the checkpoint
            if os.path.isdir(args.resume_from_ckpt):
                sd = get_fp32_state_dict_from_zero_checkpoint(
                    args.resume_from_ckpt)
            else:
                sd = torch.load(args.resume_from_ckpt)
            # Process the state dict
            if 'module' in sd:
                sd = {k[len('module.'):]: v for k, v in sd['module'].items()}
                import_openfold_weights_(model=model_module, state_dict=sd)
            elif 'state_dict' in sd:
                import_openfold_weights_(
                    model=model_module, state_dict=sd['state_dict'])
            else:
                # Loading from pre-trained model
                sd = {'model.'+k: v for k, v in sd.items()}
                import_openfold_weights_(model=model_module, state_dict=sd)
            logging.info("Successfully loaded model weights...")

        else:  # Loads a checkpoint to start from a specific time step
            if os.path.isdir(args.resume_from_ckpt):
                last_global_step = get_global_step_from_zero_checkpoint(
                    args.resume_from_ckpt)
            else:
                sd = torch.load(args.resume_from_ckpt)
                last_global_step = int(sd['global_step'])
            model_module.resume_last_lr_step(last_global_step)
            logging.info("Successfully loaded last lr step...")

    if args.resume_from_jax_params:
        model_module.load_from_jax(args.resume_from_jax_params)
        logging.info(f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")

    # TorchScript components of the model
    if(args.script_modules):
        script_preset_(model_module)

    if "multimer" in args.config_preset:
        data_module = OpenFoldMultimerDataModule(
            config=config.data,
            batch_seed=args.seed,
            **vars(args)
        )
    else:
        data_module = OpenFoldDataModule(
            config=config.data,
            batch_seed=args.seed,
            **vars(args)
        )

    data_module.prepare_data()
    data_module.setup()

    callbacks = []
    if(args.checkpoint_every_epoch):
        mc = ModelCheckpoint(
            every_n_epochs=1,
            auto_insert_metric_name=False,
            save_top_k=-1,  # keep every epoch's checkpoint
        )
        callbacks.append(mc)

    if(args.early_stopping):
        es = EarlyStoppingVerbose(
            monitor="val/lddt_ca",
            min_delta=args.min_delta,
            patience=args.patience,
            verbose=False,
            mode="max",
            check_finite=True,
            strict=True,
        )
        callbacks.append(es)

    if(args.log_performance):
        global_batch_size = args.num_nodes * args.gpus
        perf = PerformanceLoggingCallback(
            log_file=os.path.join(args.output_dir, "performance_log.json"),
            global_batch_size=global_batch_size,
        )
        callbacks.append(perf)

    if(args.log_lr):
        lr_monitor = LearningRateMonitor(logging_interval="step")
        callbacks.append(lr_monitor)

    loggers = []
    is_rank_zero = args.mpi_plugin and (int(os.environ.get("PMI_RANK", "0")) == 0)
    if(args.wandb):
        if args.mpi_plugin and is_rank_zero:
            wandb_init_dict = dict(
                name=args.experiment_name,
                project=args.wandb_project,
                id=args.wandb_id,
                dir=args.output_dir,
                resume="allow",
                anonymous=None,
                entity=args.wandb_entity
            )
            wandb.run = wandb.init(**wandb_init_dict)

        wdb_logger = WandbLogger(
            name=args.experiment_name,
            save_dir=args.output_dir,
            id=args.wandb_id,
            project=args.wandb_project,
            **{"entity": args.wandb_entity}
        )
        loggers.append(wdb_logger)

    cluster_environment = MPIEnvironment() if args.mpi_plugin else None
    if(args.deepspeed_config_path is not None):
        strategy = DeepSpeedStrategy(
            config=args.deepspeed_config_path,
            cluster_environment=cluster_environment,
        )
        if(args.wandb and is_rank_zero):
            wdb_logger.experiment.save(args.deepspeed_config_path)
            wdb_logger.experiment.save("openfold/config.py")
    elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1:
        strategy = DDPStrategy(find_unused_parameters=False,
                               cluster_environment=cluster_environment)
    else:
        strategy = None

    if(args.wandb and is_rank_zero):
        freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt"
        os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
        wdb_logger.experiment.save(freeze_path)

    trainer = pl.Trainer(
        num_nodes=args.num_nodes,
        devices=args.gpus,
        precision=args.precision,
        max_epochs=args.max_epochs,
        default_root_dir=args.output_dir,
        strategy=strategy,
        callbacks=callbacks,
        logger=loggers,
    )

    if (args.resume_model_weights_only):
        ckpt_path = None
    else:
        ckpt_path = args.resume_from_ckpt

    trainer.fit(
        model_module, 
        datamodule=data_module,
        ckpt_path=ckpt_path,
    )


def bool_type(bool_str: str):
    bool_str_lower = bool_str.lower()
    if bool_str_lower in ('false', 'f', 'no', 'n', '0'):
        return False
    elif bool_str_lower in ('true', 't', 'yes', 'y', '1'):
        return True
    else:
        raise ValueError(f'Cannot interpret {bool_str} as bool')


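# Example invocation (illustrative paths and flags; see the parser below):
#
#   python3 train_openfold.py mmcif_dir/ alignment_dir/ template_mmcif_dir/ \
#       output_dir/ 2021-10-10 \
#       --template_release_dates_cache_path mmcif_cache.json \
#       --config_preset initial_training \
#       --seed 42 --gpus 2 --precision 32 --checkpoint_every_epoch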
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "train_data_dir", type=str,
        help="Directory containing training mmCIF files"
    )
    parser.add_argument(
        "train_alignment_dir", type=str,
        help="Directory containing precomputed training alignments"
    )
    parser.add_argument(
        "template_mmcif_dir", type=str,
        help="Directory containing mmCIF files to search for templates"
    )
    parser.add_argument(
        "output_dir", type=str,
        help='''Directory in which to output checkpoints, logs, etc. Ignored
                if not on rank 0'''
    )
    parser.add_argument(
        "max_template_date", type=str,
        help='''Cutoff for all templates. In training mode, templates are also
                filtered by the release date of the target'''
    )
    parser.add_argument(
        "--train_mmcif_data_cache_path", type=str, default=None,
        help="""Path to a JSON file recording metadata for the mmCIF
                structures used during training"""
    )
    parser.add_argument(
        "--use_single_seq_mode", type=bool_type, default=False,
        help="Use single sequence embeddings instead of MSAs."
    )
    parser.add_argument(
        "--distillation_data_dir", type=str, default=None,
        help="Directory containing training PDB files"
    )
    parser.add_argument(
        "--distillation_alignment_dir", type=str, default=None,
        help="Directory containing precomputed distillation alignments"
    )
    parser.add_argument(
        "--val_data_dir", type=str, default=None,
        help="Directory containing validation mmCIF files"
    )
    parser.add_argument(
        "--val_alignment_dir", type=str, default=None,
        help="Directory containing precomputed validation alignments"
    )
    parser.add_argument(
        "--val_mmcif_data_cache_path", type=str, default=None,
        help="""Path to a JSON file recording metadata for the mmCIF
                structures used during validation"""
    )
    parser.add_argument(
        "--kalign_binary_path", type=str, default='/usr/bin/kalign',
        help="Path to the kalign binary"
    )
    parser.add_argument(
        "--train_filter_path", type=str, default=None,
        help='''Optional path to a text file containing names of training
                examples to include, one per line. Used to filter the training
                set'''
    )
    parser.add_argument(
        "--distillation_filter_path", type=str, default=None,
        help="""See --train_filter_path"""
    )
    parser.add_argument(
        "--obsolete_pdbs_file_path", type=str, default=None,
        help="""Path to obsolete.dat file containing list of obsolete PDBs and 
             their replacements."""
    )
    parser.add_argument(
        "--template_release_dates_cache_path", type=str, default=None,
        help="""Output of scripts/generate_mmcif_cache.py run on template mmCIF
                files."""
    )
    parser.add_argument(
        "--use_small_bfd", type=bool_type, default=False,
        help="Whether to use a reduced version of the BFD database"
    )
    parser.add_argument(
        "--seed", type=int, default=None,
        help="Random seed"
    )
    parser.add_argument(
        "--deepspeed_config_path", type=str, default=None,
        help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
    )
    parser.add_argument(
        "--checkpoint_every_epoch", action="store_true", default=False,
        help="""Whether to checkpoint at the end of every training epoch"""
    )
    parser.add_argument(
        "--early_stopping", type=bool_type, default=False,
        help="Whether to stop training when validation loss fails to decrease"
    )
    parser.add_argument(
        "--min_delta", type=float, default=0,
        help="""The smallest decrease in validation loss that counts as an 
                improvement for the purposes of early stopping"""
    )
    parser.add_argument(
        "--patience", type=int, default=3,
        help="Early stopping patience"
    )
    parser.add_argument(
        "--resume_from_ckpt", type=str, default=None,
        help="Path to a model checkpoint from which to restore training state"
    )
    parser.add_argument(
        "--resume_model_weights_only", type=bool_type, default=False,
        help="Whether to load just model weights as opposed to training state"
    )
    parser.add_argument(
        "--resume_from_jax_params", type=str, default=None,
        help="""Path to an .npz JAX parameter file with which to initialize the model"""
    )
    parser.add_argument(
        "--log_performance", type=bool_type, default=False,
        help="Measure performance"
    )
    parser.add_argument(
        "--wandb", action="store_true", default=False,
        help="Whether to log metrics to Weights & Biases"
    )
    parser.add_argument(
        "--experiment_name", type=str, default=None,
        help="Name of the current experiment. Used for wandb logging"
    )
    parser.add_argument(
        "--wandb_id", type=str, default=None,
        help="ID of a previous run to be resumed"
    )
    parser.add_argument(
        "--wandb_project", type=str, default=None,
        help="Name of the wandb project to which this run will belong"
    )
    parser.add_argument(
        "--wandb_entity", type=str, default=None,
        help="wandb username or team name to which runs are attributed"
    )
    parser.add_argument(
        "--script_modules", type=bool_type, default=False,
        help="Whether to TorchScript eligible components of the model"
    )
    parser.add_argument(
        "--train_chain_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--distillation_chain_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--train_epoch_len", type=int, default=10000,
        help=(
            "The virtual length of each training epoch. Stochastic filtering "
            "of training data means that training datasets have no "
            "well-defined length. This virtual length affects frequency of "
            "validation & checkpointing (by default, one of each per epoch)."
        )
    )
    parser.add_argument(
        "--log_lr", action="store_true", default=False,
        help="Whether to log the actual learning rate"
    )
    parser.add_argument(
        "--config_preset", type=str, default="initial_training",
        help=(
            'Config setting. Choose e.g. "initial_training", "finetuning", '
            '"model_1", etc. By default, the actual values in the config are '
            'used.'
        )
    )
    parser.add_argument(
        "--_distillation_structure_index_path", type=str, default=None,
    )
    parser.add_argument(
        "--alignment_index_path", type=str, default=None,
        help="Training alignment index. See the README for instructions."
    )
    parser.add_argument(
        "--distillation_alignment_index_path", type=str, default=None,
        help="Distillation alignment index. See the README for instructions."
    )
    parser.add_argument(
        "--experiment_config_json", default="",
        help="Path to a JSON file with custom config values that override the config preset",
    )
    parser.add_argument("--num_nodes", type=int, default=1)
    parser.add_argument("--gpus", type=int, default=None)
    parser.add_argument("--max_epochs", type=int, default=None)
    parser.add_argument("--precision", type=str, default="32")
    parser.add_argument("--log_every_n_steps", type=int, default=50)
    parser.add_argument("--accumulate_grad_batches", type=int, default=1)
    parser.add_argument("--flush_logs_every_n_steps", type=int, default=5)
    parser.add_argument("--num_sanity_val_steps", type=int, default=0)
    parser.add_argument("--mpi_plugin", action="store_true", default=False)
    # parser = pl.Trainer.add_argparse_args(parser)

    # Disable the initial validation pass
    parser.set_defaults(
        num_sanity_val_steps=0,
    )

    # Remove some buggy/redundant arguments introduced by the Trainer
    remove_arguments(
        parser, 
        [
            "--accelerator", 
            "--resume_from_checkpoint",
            "--reload_dataloaders_every_epoch",
            "--reload_dataloaders_every_n_epochs",
        ]
    ) 

    args = parser.parse_args()

    if(args.seed is None and 
        ((args.gpus is not None and args.gpus > 1) or 
         (args.num_nodes is not None and args.num_nodes > 1))):
        raise ValueError("For distributed training, --seed must be specified")

    if(str(args.precision) == "16" and args.deepspeed_config_path is not None):
        raise ValueError("DeepSpeed and FP16 training are not compatible")

    if(args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
        raise ValueError(
            "Choose either pretrained JAX weights (--resume_from_jax_params) "
            "or a checkpoint path (--resume_from_ckpt), not both"
        )

    # This re-applies the training-time filters at the beginning of every epoch
    args.reload_dataloaders_every_n_epochs = 1

    main(args)