"deploy/compoundai/api-server/go.sum" did not exist on "14ce7e03bedde2b7632fa526a01a187f4b374996"
train_openfold.py 24.8 KB
Newer Older
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
1
2
3
import argparse
import logging
import os
4
import sys
5
import json
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
6
7

import pytorch_lightning as pl
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
8
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
Jennifer's avatar
Jennifer committed
9
from pytorch_lightning.callbacks import DeviceStatsMonitor
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
10
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
11
from pytorch_lightning.loggers import WandbLogger
Jennifer Wei's avatar
Jennifer Wei committed
12
13
14
from pytorch_lightning.strategies import DDPStrategy, DeepSpeedStrategy
from pytorch_lightning.plugins.environments import MPIEnvironment
from pytorch_lightning import seed_everything
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
15
import torch
Jennifer Wei's avatar
Jennifer Wei committed
16
import wandb
17
from deepspeed.utils import zero_to_fp32 
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
18
19

from openfold.config import model_config
20
from openfold.data.data_modules import OpenFoldDataModule, OpenFoldMultimerDataModule
21
from openfold.model.model import AlphaFold
22
from openfold.model.torchscript import script_preset_
23
from openfold.np import residue_constants
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
24
25
26
from openfold.utils.callbacks import (
    EarlyStoppingVerbose,
)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
27
from openfold.utils.exponential_moving_average import ExponentialMovingAverage
28
from openfold.utils.loss import AlphaFoldLoss, lddt_ca
29
from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
30
from openfold.utils.multi_chain_permutation import multi_chain_permutation_align
31
from openfold.utils.superimposition import superimpose
32
from openfold.utils.tensor_utils import tensor_tree_map
33
34
35
36
37
from openfold.utils.validation_metrics import (
    drmsd,
    gdt_ts,
    gdt_ha,
)
38
39
from openfold.utils.import_weights import (
    import_jax_weights_,
40
    import_openfold_weights_
41
)
Marta's avatar
Marta committed
42
43
from openfold.utils.logger import PerformanceLoggingCallback

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
44
45
46
47
48

class OpenFoldWrapper(pl.LightningModule):
    def __init__(self, config):
        super(OpenFoldWrapper, self).__init__()
        self.config = config
49
        self.model = AlphaFold(config)
50
        self.is_multimer = self.config.globals.is_multimer
51

52
        self.loss = AlphaFoldLoss(config.loss)
53

54
55
56
        self.ema = ExponentialMovingAverage(
            model=self.model, decay=config.ema.decay
        )
57

58
        self.cached_weights = None
59
        self.last_lr_step = -1
Jennifer Wei's avatar
Jennifer Wei committed
60
        self.save_hyperparameters()
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
61
62
63
64

    def forward(self, batch):
        return self.model(batch)

65
66
67
68
69
70
    def _log(self, loss_breakdown, batch, outputs, train=True):
        phase = "train" if train else "val"
        for loss_name, indiv_loss in loss_breakdown.items():
            self.log(
                f"{phase}/{loss_name}", 
                indiv_loss, 
Jennifer Wei's avatar
Jennifer Wei committed
71
72
                prog_bar=(loss_name == 'loss'),
                on_step=train, on_epoch=(not train), logger=True, sync_dist=False,
73
74
            )

75
            if (train):
76
77
78
                self.log(
                    f"{phase}/{loss_name}_epoch",
                    indiv_loss,
Jennifer Wei's avatar
Jennifer Wei committed
79
                    on_step=False, on_epoch=True, logger=True, sync_dist=False,
80
81
82
83
                )

        with torch.no_grad():
            other_metrics = self._compute_validation_metrics(
84
                batch,
85
86
87
88
                outputs,
                superimposition_metrics=(not train)
            )

89
        for k, v in other_metrics.items():
90
            self.log(
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
91
92
                f"{phase}/{k}",
                torch.mean(v),
Jennifer Wei's avatar
Jennifer Wei committed
93
94
                prog_bar = (k == 'loss'),
                on_step=False, on_epoch=True, logger=True, sync_dist=False,
95
96
            )

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
97
    def training_step(self, batch, batch_idx):
98
        if (self.ema.device != batch["aatype"].device):
99
100
            self.ema.to(batch["aatype"].device)

101
102
        ground_truth = batch.pop('gt_features', None)

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
103
104
        # Run the model
        outputs = self(batch)
105

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
106
107
108
        # Remove the recycling dimension
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

109
110
111
112
113
        if self.is_multimer:
            batch = multi_chain_permutation_align(out=outputs,
                                                  features=batch,
                                                  ground_truth=ground_truth)

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
114
        # Compute loss
115
116
117
        loss, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
        )
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
118

119
120
        # Log it
        self._log(loss_breakdown, batch, outputs)
Gustaf Ahdritz's avatar
Fixes  
Gustaf Ahdritz committed
121

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
122
        return loss
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
123

124
125
    def on_before_zero_grad(self, *args, **kwargs):
        self.ema.update(self.model)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
126

127
128
    def validation_step(self, batch, batch_idx):
        # At the start of validation, load the EMA weights
129
        if (self.cached_weights is None):
130
            # model.state_dict() contains references to model weights rather
131
            # than copies. Therefore, we need to clone them before calling
132
            # load_state_dict().
133
134
135
            def clone_param(t): return t.detach().clone()
            self.cached_weights = tensor_tree_map(
                clone_param, self.model.state_dict())
136
            self.model.load_state_dict(self.ema.state_dict()["params"])
137
138
139

        ground_truth = batch.pop('gt_features', None)

140
        # Run the model
141
142
        outputs = self(batch)
        batch = tensor_tree_map(lambda t: t[..., -1], batch)
143
144

        batch["use_clamped_fape"] = 0.
145
146
147
148
149
150
151

        if self.is_multimer:
            batch = multi_chain_permutation_align(out=outputs,
                                                  features=batch,
                                                  ground_truth=ground_truth)

        # Compute loss and other metrics
152
153
        _, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
154
        )
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
155

156
157
        self._log(loss_breakdown, batch, outputs, train=False)
        
Jennifer Wei's avatar
Jennifer Wei committed
158
    def on_validation_epoch_end(self):
159
160
161
        # Restore the model weights to normal
        self.model.load_state_dict(self.cached_weights)
        self.cached_weights = None
162

163
164
165
166
167
    def _compute_validation_metrics(self,
                                    batch,
                                    outputs,
                                    superimposition_metrics=False
                                    ):
168
        metrics = {}
169

170
171
172
        gt_coords = batch["all_atom_positions"]
        pred_coords = outputs["final_atom_positions"]
        all_atom_mask = batch["all_atom_mask"]
173

174
175
176
177
178
179
180
        # This is super janky for superimposition. Fix later
        gt_coords_masked = gt_coords * all_atom_mask[..., None]
        pred_coords_masked = pred_coords * all_atom_mask[..., None]
        ca_pos = residue_constants.atom_order["CA"]
        gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :]
        pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :]
        all_atom_mask_ca = all_atom_mask[..., ca_pos]
181

182
183
184
185
186
187
188
        lddt_ca_score = lddt_ca(
            pred_coords,
            gt_coords,
            all_atom_mask,
            eps=self.config.globals.eps,
            per_residue=False,
        )
189

190
        metrics["lddt_ca"] = lddt_ca_score
191

192
193
194
        drmsd_ca_score = drmsd(
            pred_coords_masked_ca,
            gt_coords_masked_ca,
195
            mask=all_atom_mask_ca,  # still required here to compute n
196
        )
197

198
        metrics["drmsd_ca"] = drmsd_ca_score
199
200

        if (superimposition_metrics):
201
202
203
204
205
206
207
208
209
210
211
212
213
            superimposed_pred, alignment_rmsd = superimpose(
                gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
            )
            gdt_ts_score = gdt_ts(
                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
            )
            gdt_ha_score = gdt_ha(
                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
            )

            metrics["alignment_rmsd"] = alignment_rmsd
            metrics["gdt_ts"] = gdt_ts_score
            metrics["gdt_ha"] = gdt_ha_score
214

215
216
        return metrics

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
217
218
    def configure_optimizers(self, 
        learning_rate: float = 1e-3,
219
        eps: float = 1e-5,
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
220
221
    ) -> torch.optim.Adam:
        # Ignored as long as a DeepSpeed optimizer is configured
222
        optimizer = torch.optim.Adam(
223
224
            self.model.parameters(),
            lr=learning_rate,
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
225
226
            eps=eps
        )
227
228
229
230
231
232

        if self.last_lr_step != -1:
            for group in optimizer.param_groups:
                if 'initial_lr' not in group:
                    group['initial_lr'] = learning_rate

233
234
        lr_scheduler = AlphaFoldLRScheduler(
            optimizer,
235
            last_epoch=self.last_lr_step
236
        )
237

238
239
240
241
242
243
244
245
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
                "name": "AlphaFoldLRScheduler",
            }
        }
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
246

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
247
    def on_load_checkpoint(self, checkpoint):
248
        ema = checkpoint["ema"]
249
250
251
        if (not self.model.template_config.enabled):
            ema["params"] = {k: v for k,
                             v in ema["params"].items() if not "template" in k}
252
        self.ema.load_state_dict(ema)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
253

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
254
255
256
    def on_save_checkpoint(self, checkpoint):
        checkpoint["ema"] = self.ema.state_dict()

257
258
259
    def resume_last_lr_step(self, lr_step):
        self.last_lr_step = lr_step

260
261
    def load_from_jax(self, jax_path):
        model_basename = os.path.splitext(
262
263
264
            os.path.basename(
                os.path.normpath(jax_path)
            )
265
266
267
        )[0]
        model_version = "_".join(model_basename.split("_")[1:])
        import_jax_weights_(
268
            self.model, jax_path, version=model_version
269
270
        )

271
272
273
274
275
276
277
278
279
280
281
282
def get_model_state_dict_from_ds_checkpoint(checkpoint_dir):
    latest_path = os.path.join(checkpoint_dir, 'latest')
    if os.path.isfile(latest_path):
        with open(latest_path, 'r') as fd:
            tag = fd.read().strip()
    else:
        raise ValueError(f"Unable to find 'latest' file at {latest_path}")

    ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
    _DS_CHECKPOINT_VERSION = 2  # based on manual parsing of checkpoint files
    state_file = zero_to_fp32.get_model_state_file(ds_checkpoint_dir, _DS_CHECKPOINT_VERSION)
    return torch.load(state_file)
283

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
284
def main(args):
285
    if(args.seed is not None):
286
        seed_everything(args.seed, workers=True) 
287

288
289
290
    is_low_precision = args.precision in [
        "bf16-mixed", "16", "bf16", "16-true", "16-mixed", "bf16-mixed"]

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
291
    config = model_config(
292
        args.config_preset, 
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
293
        train=True, 
294
        low_prec=is_low_precision,
295
    ) 
296
297
298
299
300
    if args.experiment_config_json: 
        with open(args.experiment_config_json, 'r') as f:
            custom_config_dict = json.load(f)
        config.update_from_flattened_dict(custom_config_dict)

301
302
    model_module = OpenFoldWrapper(config)

303
304
305
306
    if args.resume_from_ckpt:
        if args.resume_model_weights_only:
            # Load the checkpoint
            if os.path.isdir(args.resume_from_ckpt):
307
                sd = zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint(
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
                    args.resume_from_ckpt)
            else:
                sd = torch.load(args.resume_from_ckpt)
            # Process the state dict
            if 'module' in sd:
                sd = {k[len('module.'):]: v for k, v in sd['module'].items()}
                import_openfold_weights_(model=model_module, state_dict=sd)
            elif 'state_dict' in sd:
                import_openfold_weights_(
                    model=model_module, state_dict=sd['state_dict'])
            else:
                # Loading from pre-trained model
                sd = {'model.'+k: v for k, v in sd.items()}
                import_openfold_weights_(model=model_module, state_dict=sd)
            logging.info("Successfully loaded model weights...")

        else:  # Loads a checkpoint to start from a specific time step
            if os.path.isdir(args.resume_from_ckpt):
326
                sd = get_model_state_dict_from_ds_checkpoint(args.resume_from_ckpt)
327
328
            else:
                sd = torch.load(args.resume_from_ckpt)
329
            last_global_step = int(sd['global_step'])
330
331
332
333
            model_module.resume_last_lr_step(last_global_step)
            logging.info("Successfully loaded last lr step...")

    if args.resume_from_jax_params:
Lucas Bickmann's avatar
Lucas Bickmann committed
334
        model_module.load_from_jax(args.resume_from_jax_params)
335
336
337
        logging.info(
            f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")

338
    # TorchScript components of the model
339
    if (args.script_modules):
340
        script_preset_(model_module)
341

342
343
    if "multimer" in args.config_preset:
        data_module = OpenFoldMultimerDataModule(
344
345
346
347
            config=config.data,
            batch_seed=args.seed,
            **vars(args)
        )
348
349
    else:
        data_module = OpenFoldDataModule(
350
            config=config.data,
351
352
353
            batch_seed=args.seed,
            **vars(args)
        )
354

355
356
    data_module.prepare_data()
    data_module.setup()
357

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
358
    callbacks = []
359
    if (args.checkpoint_every_epoch):
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
360
        mc = ModelCheckpoint(
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
361
            every_n_epochs=1,
362
363
            auto_insert_metric_name=False,
            save_top_k=-1,
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
364
365
366
        )
        callbacks.append(mc)

367
    if (args.early_stopping):
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
368
        es = EarlyStoppingVerbose(
369
            monitor="val/lddt_ca",
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
370
371
372
            min_delta=args.min_delta,
            patience=args.patience,
            verbose=False,
373
            mode="max",
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
374
375
376
377
            check_finite=True,
            strict=True,
        )
        callbacks.append(es)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
378

379
    if (args.log_performance):
Marta's avatar
Marta committed
380
381
        global_batch_size = args.num_nodes * args.gpus
        perf = PerformanceLoggingCallback(
Marta's avatar
Marta committed
382
            log_file=os.path.join(args.output_dir, "performance_log.json"),
Marta's avatar
Marta committed
383
384
385
            global_batch_size=global_batch_size,
        )
        callbacks.append(perf)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
386

387
    if (args.log_lr):
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
388
389
390
        lr_monitor = LearningRateMonitor(logging_interval="step")
        callbacks.append(lr_monitor)

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
391
    loggers = []
392
    is_rank_zero = args.mpi_plugin and (int(os.environ.get("PMI_RANK")) == 0)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
393
    if(args.wandb):
Jennifer Wei's avatar
Jennifer Wei committed
394
395
396
397
398
399
400
401
402
403
404
405
        if args.mpi_plugin and is_rank_zero:
            wandb_init_dict = dict(
                name=args.experiment_name,
                project=args.wandb_project,
                id=args.wandb_id,
                dir=args.output_dir,
                resume="allow",
                anonymous=None,
                entity=args.wandb_entity
            )
            wandb.run = wandb.init(**wandb_init_dict)

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
406
407
408
409
410
411
412
413
414
        wdb_logger = WandbLogger(
            name=args.experiment_name,
            save_dir=args.output_dir,
            id=args.wandb_id,
            project=args.wandb_project,
            **{"entity": args.wandb_entity}
        )
        loggers.append(wdb_logger)

Jennifer Wei's avatar
Jennifer Wei committed
415
    cluster_environment = MPIEnvironment() if args.mpi_plugin else None
416
    if(args.deepspeed_config_path is not None):
Jennifer Wei's avatar
Jennifer Wei committed
417
        strategy = DeepSpeedStrategy(
418
            config=args.deepspeed_config_path,
Jennifer Wei's avatar
Jennifer Wei committed
419
            cluster_environment=cluster_environment,
420
        )
Jennifer Wei's avatar
Jennifer Wei committed
421
        if(args.wandb and is_rank_zero):
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
422
            wdb_logger.experiment.save(args.deepspeed_config_path)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
423
            wdb_logger.experiment.save("openfold/config.py")
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
424
    elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1:
Jennifer Wei's avatar
Jennifer Wei committed
425
426
        strategy = DDPStrategy(find_unused_parameters=False,
                               cluster_environment=cluster_environment)
427
428
    else:
        strategy = None
429
 
Jennifer Wei's avatar
Jennifer Wei committed
430
    if(args.wandb and is_rank_zero):
431
432
433
434
        freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt"
        os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
        wdb_logger.experiment.save(f"{freeze_path}")

435
436
437
438
439
440
441
442
443
444
445
    trainer_kws = ['num_nodes', 'precision', 'max_epochs', 'log_every_n_steps',
                   'flush_logs_ever_n_steps', 'num_sanity_val_steps', 'reload_dataloaders_every_n_epochs']
    trainer_args = {k: v for k, v in vars(args).items() if k in trainer_kws}
    trainer_args.update({
        'default_root_dir': args.output_dir,
        'strategy': strategy,
        'callbacks': callbacks,
        'logger': loggers,
    })
    trainer = pl.Trainer(**trainer_args)

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
446

Jennifer's avatar
Jennifer committed
447
    if (args.resume_model_weights_only):
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
448
449
450
451
452
        ckpt_path = None
    else:
        ckpt_path = args.resume_from_ckpt

    trainer.fit(
453
        model_module,
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
454
455
        datamodule=data_module,
        ckpt_path=ckpt_path,
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
456
457
458
    )


Marta's avatar
Marta committed
459
460
461
462
463
464
465
466
467
468
def bool_type(bool_str: str):
    bool_str_lower = bool_str.lower()
    if bool_str_lower in ('false', 'f', 'no', 'n', '0'):
        return False
    elif bool_str_lower in ('true', 't', 'yes', 'y', '1'):
        return True
    else:
        raise ValueError(f'Cannot interpret {bool_str} as bool')


Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
469
470
471
472
473
474
475
476
477
478
479
480
481
482
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "train_data_dir", type=str,
        help="Directory containing training mmCIF files"
    )
    parser.add_argument(
        "train_alignment_dir", type=str,
        help="Directory containing precomputed training alignments"
    )
    parser.add_argument(
        "template_mmcif_dir", type=str,
        help="Directory containing mmCIF files to search for templates"
    )
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
483
484
485
486
487
    parser.add_argument(
        "output_dir", type=str,
        help='''Directory in which to output checkpoints, logs, etc. Ignored
                if not on rank 0'''
    )
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
488
489
    parser.add_argument(
        "max_template_date", type=str,
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
490
491
        help='''Cutoff for all templates. In training mode, templates are also 
                filtered by the release date of the target'''
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
492
    )
493
494
    parser.add_argument(
        "--train_mmcif_data_cache_path", type=str, default=None,
495
496
        help="Path to the json file which records all the information of mmcif structures used during training"
    )
497
    parser.add_argument(
498
        "--use_single_seq_mode", type=str, default=False,
499
        help="Use single sequence embeddings instead of MSAs."
500
    )
501
502
503
504
505
506
507
508
    parser.add_argument(
        "--distillation_data_dir", type=str, default=None,
        help="Directory containing training PDB files"
    )
    parser.add_argument(
        "--distillation_alignment_dir", type=str, default=None,
        help="Directory containing precomputed distillation alignments"
    )
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
509
510
511
512
513
514
515
516
    parser.add_argument(
        "--val_data_dir", type=str, default=None,
        help="Directory containing validation mmCIF files"
    )
    parser.add_argument(
        "--val_alignment_dir", type=str, default=None,
        help="Directory containing precomputed validation alignments"
    )
517
518
    parser.add_argument(
        "--val_mmcif_data_cache_path", type=str, default=None,
Dingquan Yu's avatar
Dingquan Yu committed
519
        help="path to the json file which records all the information of mmcif structures used during validation"
520
    )
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
521
522
523
524
525
    parser.add_argument(
        "--kalign_binary_path", type=str, default='/usr/bin/kalign',
        help="Path to the kalign binary"
    )
    parser.add_argument(
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
526
527
528
529
        "--train_filter_path", type=str, default=None,
        help='''Optional path to a text file containing names of training
                examples to include, one per line. Used to filter the training 
                set'''
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
530
531
    )
    parser.add_argument(
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
532
533
        "--distillation_filter_path", type=str, default=None,
        help="""See --train_filter_path"""
534
    )
535
536
537
538
539
    parser.add_argument(
        "--obsolete_pdbs_file_path", type=str, default=None,
        help="""Path to obsolete.dat file containing list of obsolete PDBs and 
             their replacements."""
    )
540
541
    parser.add_argument(
        "--template_release_dates_cache_path", type=str, default=None,
542
543
        help="""Output of scripts/generate_mmcif_cache.py run on template mmCIF
                files."""
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
544
545
    )
    parser.add_argument(
Marta's avatar
Marta committed
546
        "--use_small_bfd", type=bool_type, default=False,
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
547
548
549
        help="Whether to use a reduced version of the BFD database"
    )
    parser.add_argument(
550
551
        "--seed", type=int, default=None,
        help="Random seed"
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
552
    )
553
554
555
556
    parser.add_argument(
        "--deepspeed_config_path", type=str, default=None,
        help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
    )
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
557
    parser.add_argument(
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
558
559
        "--checkpoint_every_epoch", action="store_true", default=False,
        help="""Whether to checkpoint at the end of every training epoch"""
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
560
561
    )
    parser.add_argument(
Marta's avatar
Marta committed
562
        "--early_stopping", type=bool_type, default=False,
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
563
564
565
566
567
568
569
570
571
572
573
        help="Whether to stop training when validation loss fails to decrease"
    )
    parser.add_argument(
        "--min_delta", type=float, default=0,
        help="""The smallest decrease in validation loss that counts as an 
                improvement for the purposes of early stopping"""
    )
    parser.add_argument(
        "--patience", type=int, default=3,
        help="Early stopping patience"
    )
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
574
575
576
577
578
    parser.add_argument(
        "--resume_from_ckpt", type=str, default=None,
        help="Path to a model checkpoint from which to restore training state"
    )
    parser.add_argument(
Marta's avatar
Marta committed
579
        "--resume_model_weights_only", type=bool_type, default=False,
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
580
581
        help="Whether to load just model weights as opposed to training state"
    )
Lucas Bickmann's avatar
Lucas Bickmann committed
582
    parser.add_argument(
583
584
        "--resume_from_jax_params", type=str, default=None,
        help="""Path to an .npz JAX parameter file with which to initialize the model"""
Lucas Bickmann's avatar
Lucas Bickmann committed
585
    )
Marta's avatar
Marta committed
586
    parser.add_argument(
587
        "--log_performance", type=bool_type, default=False,
Marta's avatar
Marta committed
588
589
        help="Measure performance"
    )
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
590
591
    parser.add_argument(
        "--wandb", action="store_true", default=False,
592
        help="Whether to log metrics to Weights & Biases"
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
593
594
595
    )
    parser.add_argument(
        "--experiment_name", type=str, default=None,
596
        help="Name of the current experiment. Used for wandb logging"
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
597
598
599
    )
    parser.add_argument(
        "--wandb_id", type=str, default=None,
600
        help="ID of a previous run to be resumed"
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
601
602
603
    )
    parser.add_argument(
        "--wandb_project", type=str, default=None,
604
        help="Name of the wandb project to which this run will belong"
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
605
606
607
    )
    parser.add_argument(
        "--wandb_entity", type=str, default=None,
608
        help="wandb username or team name to which runs are attributed"
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
609
    )
610
611
612
613
    parser.add_argument(
        "--script_modules", type=bool_type, default=False,
        help="Whether to TorchScript eligible components of them model"
    )
Gustaf Ahdritz's avatar
Fixes  
Gustaf Ahdritz committed
614
    parser.add_argument(
615
        "--train_chain_data_cache_path", type=str, default=None,
Gustaf Ahdritz's avatar
Fixes  
Gustaf Ahdritz committed
616
617
    )
    parser.add_argument(
618
        "--distillation_chain_data_cache_path", type=str, default=None,
Gustaf Ahdritz's avatar
Fixes  
Gustaf Ahdritz committed
619
620
621
    )
    parser.add_argument(
        "--train_epoch_len", type=int, default=10000,
622
623
624
625
626
627
        help=(
            "The virtual length of each training epoch. Stochastic filtering "
            "of training data means that training datasets have no "
            "well-defined length. This virtual length affects frequency of "
            "validation & checkpointing (by default, one of each per epoch)."
        )
628
    )
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
629
    parser.add_argument(
630
631
        "--log_lr", action="store_true", default=False,
        help="Whether to log the actual learning rate"
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
632
    )
633
    parser.add_argument(
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
        "--config_preset", type=str, default="initial_training",
        help=(
            'Config setting. Choose e.g. "initial_training", "finetuning", '
            '"model_1", etc. By default, the actual values in the config are '
            'used.'
        )
    )
    parser.add_argument(
        "--_distillation_structure_index_path", type=str, default=None,
    )
    parser.add_argument(
        "--alignment_index_path", type=str, default=None,
        help="Training alignment index. See the README for instructions."
    )
    parser.add_argument(
        "--distillation_alignment_index_path", type=str, default=None,
        help="Distillation alignment index. See the README for instructions."
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
651
    )
652
653
654
    parser.add_argument(
        "--experiment_config_json", default="", help="Path to a json file with custom config values to overwrite config setting",
    )
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
    parser.add_argument(
        "--gpus", type=int, default=1, help='For determining optimal strategy and effective batch size.'
    )
    parser.add_argument("--mpi_plugin", action="store_true", default=False,
                        help="Whether to use MPI for parallele processing")

    trainer_group = parser.add_argument_group(
        'Arguments to pass to PyTorch Lightning Trainer')
    trainer_group.add_argument(
        "--num_nodes", type=int, default=1,
    )
    trainer_group.add_argument(
        "--precision", type=str, default='bf16',
        help='Sets precision, lower precision improves runtime performance.',
    )
    trainer_group.add_argument(
        "--max_epochs", type=int, default=1,
    )
    trainer_group.add_argument(
        "--log_every_n_steps", type=int, default=25,
    )
    trainer_group.add_argument(
        "--flush_logs_every_n_steps", type=int, default=5,
    )
    trainer_group.add_argument(
        "--num_sanity_val_steps", type=int, default=0,
    )
    trainer_group.add_argument(
        "--reload_dataloaders_every_n_epochs", type=int, default=1,
    )

    trainer_group.add_argument("--accumulate_grad_batches", type=int, default=1,
                               help="Accumulate gradients over k batches before next optimizer step.")
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
688

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
689
690
    args = parser.parse_args()

691
692
    if (args.seed is None and
        ((args.gpus is not None and args.gpus > 1) or
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
693
694
695
         (args.num_nodes is not None and args.num_nodes > 1))):
        raise ValueError("For distributed training, --seed must be specified")

696
    if (str(args.precision) == "16" and args.deepspeed_config_path is not None):
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
697
698
        raise ValueError("DeepSpeed and FP16 training are not compatible")

699
700
701
    if (args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
        raise ValueError(
            "Choose between loading pretrained Jax-weights and a checkpoint-path")
702

703

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
704
    main(args)