import argparse
import logging
import os
import sys
import json

import pytorch_lightning as pl
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DeepSpeedStrategy, DDPStrategy
from pytorch_lightning import seed_everything
import torch
from deepspeed.utils import zero_to_fp32

from openfold.config import model_config
from openfold.data.data_modules import OpenFoldDataModule, OpenFoldMultimerDataModule
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants
from openfold.utils.callbacks import (
    EarlyStoppingVerbose,
)
from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.loss import AlphaFoldLoss, lddt_ca
from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
from openfold.utils.multi_chain_permutation import multi_chain_permutation_align
from openfold.utils.superimposition import superimpose
from openfold.utils.tensor_utils import tensor_tree_map
from openfold.utils.validation_metrics import (
    drmsd,
    gdt_ts,
    gdt_ha,
)
from openfold.utils.import_weights import (
    import_jax_weights_,
    import_openfold_weights_
)
from openfold.utils.logger import PerformanceLoggingCallback


class OpenFoldWrapper(pl.LightningModule):
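    """A thin PyTorch Lightning wrapper around the OpenFold AlphaFold model.

    Bundles the model with its training loss, an exponential moving average
    (EMA) of the weights that is swapped in for validation, and the AlphaFold
    learning-rate schedule.
    """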
    def __init__(self, config):
        super(OpenFoldWrapper, self).__init__()
        self.config = config
        self.model = AlphaFold(config)
        self.is_multimer = self.config.globals.is_multimer

        self.loss = AlphaFoldLoss(config.loss)

        self.ema = ExponentialMovingAverage(
            model=self.model, decay=config.ema.decay
        )

        self.cached_weights = None
        self.last_lr_step = -1
        self.save_hyperparameters()

    def forward(self, batch):
        return self.model(batch)

    def _log(self, loss_breakdown, batch, outputs, train=True):
        phase = "train" if train else "val"
        for loss_name, indiv_loss in loss_breakdown.items():
            self.log(
                f"{phase}/{loss_name}",
                indiv_loss,
                prog_bar=(loss_name == 'loss'),
                on_step=train, on_epoch=(not train), logger=True,
            )

            if (train):
                self.log(
                    f"{phase}/{loss_name}_epoch",
                    indiv_loss,
                    on_step=False, on_epoch=True, logger=True,
                )

        with torch.no_grad():
            other_metrics = self._compute_validation_metrics(
                batch,
                outputs,
                superimposition_metrics=(not train)
            )

        for k, v in other_metrics.items():
            self.log(
                f"{phase}/{k}",
                torch.mean(v),
                prog_bar=(k == 'loss'),
                on_step=False, on_epoch=True, logger=True,
            )

    def training_step(self, batch, batch_idx):
        if (self.ema.device != batch["aatype"].device):
            self.ema.to(batch["aatype"].device)

        ground_truth = batch.pop('gt_features', None)

        # Run the model
        outputs = self(batch)

        # Remove the recycling dimension
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

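        # Multimer ground truth is ambiguous up to a permutation of identical
        # chains; align it to the prediction before computing the loss.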
        if self.is_multimer:
            batch = multi_chain_permutation_align(out=outputs,
                                                  features=batch,
                                                  ground_truth=ground_truth)

        # Compute loss
        loss, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
        )

        # Log it
        self._log(loss_breakdown, batch, outputs)

        return loss

    def on_before_zero_grad(self, *args, **kwargs):
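        # Lightning calls this hook after optimizer.step() and before
        # optimizer.zero_grad(), so the EMA sees the freshly updated weights.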
        self.ema.update(self.model)

    def validation_step(self, batch, batch_idx):
        # At the start of validation, load the EMA weights
        if (self.cached_weights is None):
            # model.state_dict() contains references to model weights rather
            # than copies. Therefore, we need to clone them before calling
            # load_state_dict().
            def clone_param(t): return t.detach().clone()
            self.cached_weights = tensor_tree_map(
                clone_param, self.model.state_dict())
            self.model.load_state_dict(self.ema.state_dict()["params"])

        ground_truth = batch.pop('gt_features', None)

        # Run the model
        outputs = self(batch)
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

        batch["use_clamped_fape"] = 0.

        if self.is_multimer:
            batch = multi_chain_permutation_align(out=outputs,
                                                  features=batch,
                                                  ground_truth=ground_truth)

        # Compute loss and other metrics
        _, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
        )

        self._log(loss_breakdown, batch, outputs, train=False)

    def on_validation_epoch_end(self):
        # Restore the model weights to normal
        self.model.load_state_dict(self.cached_weights)
        self.cached_weights = None

    def _compute_validation_metrics(self,
                                    batch,
                                    outputs,
                                    superimposition_metrics=False
                                    ):
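        """Compute structural metrics: lddt_ca and drmsd_ca always, plus
        alignment_rmsd/gdt_ts/gdt_ha when superimposition_metrics is set."""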
        metrics = {}

        gt_coords = batch["all_atom_positions"]
        pred_coords = outputs["final_atom_positions"]
        all_atom_mask = batch["all_atom_mask"]

        # This is super janky for superimposition. Fix later
        gt_coords_masked = gt_coords * all_atom_mask[..., None]
        pred_coords_masked = pred_coords * all_atom_mask[..., None]
        ca_pos = residue_constants.atom_order["CA"]
        gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :]
        pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :]
        all_atom_mask_ca = all_atom_mask[..., ca_pos]

        lddt_ca_score = lddt_ca(
            pred_coords,
            gt_coords,
            all_atom_mask,
            eps=self.config.globals.eps,
            per_residue=False,
        )

        metrics["lddt_ca"] = lddt_ca_score

        drmsd_ca_score = drmsd(
            pred_coords_masked_ca,
            gt_coords_masked_ca,
            mask=all_atom_mask_ca,  # still required here to compute n
        )

        metrics["drmsd_ca"] = drmsd_ca_score

        if (superimposition_metrics):
            superimposed_pred, alignment_rmsd = superimpose(
                gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
            )
            gdt_ts_score = gdt_ts(
                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
            )
            gdt_ha_score = gdt_ha(
                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
            )

            metrics["alignment_rmsd"] = alignment_rmsd
            metrics["gdt_ts"] = gdt_ts_score
            metrics["gdt_ha"] = gdt_ha_score

        return metrics

    def configure_optimizers(self,
                             learning_rate: float = 1e-3,
                             eps: float = 1e-5,
                             ) -> torch.optim.Adam:
        # Ignored as long as a DeepSpeed optimizer is configured
        optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=learning_rate,
            eps=eps
        )

        if self.last_lr_step != -1:
            for group in optimizer.param_groups:
                if 'initial_lr' not in group:
                    group['initial_lr'] = learning_rate

        lr_scheduler = AlphaFoldLRScheduler(
            optimizer,
            last_epoch=self.last_lr_step
        )

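        # With "interval": "step" below, Lightning steps the scheduler after
        # every optimizer step rather than once per epoch.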
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
                "name": "AlphaFoldLRScheduler",
            }
        }

    def on_load_checkpoint(self, checkpoint):
        ema = checkpoint["ema"]
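        # If templates are disabled, the live model has no template weights;
        # drop them from the stored EMA so the key sets match.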
        if (not self.model.template_config.enabled):
            ema["params"] = {
                k: v for k, v in ema["params"].items() if "template" not in k
            }
        self.ema.load_state_dict(ema)

    def on_save_checkpoint(self, checkpoint):
        checkpoint["ema"] = self.ema.state_dict()

    def resume_last_lr_step(self, lr_step):
        self.last_lr_step = lr_step

    def load_from_jax(self, jax_path):
        model_basename = os.path.splitext(
            os.path.basename(
                os.path.normpath(jax_path)
            )
        )[0]
        model_version = "_".join(model_basename.split("_")[1:])
        import_jax_weights_(
            self.model, jax_path, version=model_version
        )


def get_model_state_dict_from_ds_checkpoint(checkpoint_dir):
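    """Extract the model state dict from a DeepSpeed ZeRO checkpoint directory.

    The directory is expected to contain a 'latest' file naming the checkpoint
    tag, with the model states stored in a subdirectory of that name.
    """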
    latest_path = os.path.join(checkpoint_dir, 'latest')
    if os.path.isfile(latest_path):
        with open(latest_path, 'r') as fd:
            tag = fd.read().strip()
    else:
        raise ValueError(f"Unable to find 'latest' file at {latest_path}")

    ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
    _DS_CHECKPOINT_VERSION = 2  # based on manual parsing of checkpoint files
    state_file = zero_to_fp32.get_model_state_file(ds_checkpoint_dir, _DS_CHECKPOINT_VERSION)
    return torch.load(state_file)


def main(args):
    if (args.seed is not None):
        seed_everything(args.seed, workers=True)

    config = model_config(
        args.config_preset,
        train=True,
        low_prec=(str(args.precision) == "16")
    )

    if args.experiment_config_json:
        with open(args.experiment_config_json, 'r') as f:
            custom_config_dict = json.load(f)
        config.update_from_flattened_dict(custom_config_dict)
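        # Keys in the JSON are flattened (dot-separated) config paths; their
        # values override those of the chosen preset.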

    model_module = OpenFoldWrapper(config)

    if args.resume_from_ckpt:
        if args.resume_model_weights_only:
            # Load the checkpoint
            if os.path.isdir(args.resume_from_ckpt):
                sd = zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint(
                    args.resume_from_ckpt)
            else:
                sd = torch.load(args.resume_from_ckpt)
            # Process the state dict
            if 'module' in sd:
                sd = {k[len('module.'):]: v for k, v in sd['module'].items()}
                import_openfold_weights_(model=model_module, state_dict=sd)
            elif 'state_dict' in sd:
                import_openfold_weights_(
                    model=model_module, state_dict=sd['state_dict'])
            else:
                # Loading from pre-trained model
                sd = {'model.'+k: v for k, v in sd.items()}
                import_openfold_weights_(model=model_module, state_dict=sd)
            logging.info("Successfully loaded model weights...")

        else:  # Loads a checkpoint to start from a specific time step
            if os.path.isdir(args.resume_from_ckpt):
                sd = get_model_state_dict_from_ds_checkpoint(args.resume_from_ckpt)
            else:
                sd = torch.load(args.resume_from_ckpt)
            last_global_step = int(sd['global_step'])
            model_module.resume_last_lr_step(last_global_step)
            logging.info("Successfully loaded last lr step...")

    if args.resume_from_jax_params:
        model_module.load_from_jax(args.resume_from_jax_params)
        logging.info(
            f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")

    # TorchScript components of the model
    if (args.script_modules):
        script_preset_(model_module)

    if "multimer" in args.config_preset:
        data_module = OpenFoldMultimerDataModule(
            config=config.data,
            batch_seed=args.seed,
            **vars(args)
        )
    else:
        data_module = OpenFoldDataModule(
            config=config.data,
            batch_seed=args.seed,
            **vars(args)
        )

    data_module.prepare_data()
    data_module.setup()

    callbacks = []
    if (args.checkpoint_every_epoch):
        mc = ModelCheckpoint(
            every_n_epochs=1,
            auto_insert_metric_name=False,
            save_top_k=-1,
        )
        callbacks.append(mc)

    if (args.early_stopping):
        es = EarlyStoppingVerbose(
            monitor="val/lddt_ca",
            min_delta=args.min_delta,
            patience=args.patience,
            verbose=False,
            mode="max",
            check_finite=True,
            strict=True,
        )
        callbacks.append(es)

    if (args.log_performance):
        global_batch_size = args.num_nodes * args.gpus
        perf = PerformanceLoggingCallback(
            log_file=os.path.join(args.output_dir, "performance_log.json"),
            global_batch_size=global_batch_size,
        )
        callbacks.append(perf)

    if (args.log_lr):
        lr_monitor = LearningRateMonitor(logging_interval="step")
        callbacks.append(lr_monitor)

    loggers = []
    if (args.wandb):
        wdb_logger = WandbLogger(
            name=args.experiment_name,
            save_dir=args.output_dir,
            id=args.wandb_id,
            project=args.wandb_project,
            config=config.to_dict(),
            **{"entity": args.wandb_entity}
        )
        loggers.append(wdb_logger)

    if (args.deepspeed_config_path is not None):
        strategy = DeepSpeedStrategy(
            config=args.deepspeed_config_path,
        )
        if (args.wandb):
            wdb_logger.experiment.save(args.deepspeed_config_path)
            wdb_logger.experiment.save("openfold/config.py")
    elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1:
        strategy = DDPStrategy(find_unused_parameters=False)
    else:
        strategy = None

    if (args.wandb):
        freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt"
        os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
        wdb_logger.experiment.save(f"{freeze_path}")

    # Raw dump of all args from pl.Trainer constructor
    trainer_kws = [
        'num_nodes', 'precision', 'max_epochs', 'log_every_n_steps',
        'flush_logs_every_n_steps', 'num_sanity_val_steps',
        'reload_dataloaders_every_n_epochs',
    ]
    trainer_args = {k: v for k, v in vars(args).items() if k in trainer_kws}
    trainer_args.update({
        'default_root_dir': args.output_dir,
        'strategy': strategy,
        'callbacks': callbacks,
        'logger': loggers,
    })
    trainer = pl.Trainer(**trainer_args)

    if (args.resume_model_weights_only):
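        # Weights were already imported above; start with fresh trainer state.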
        ckpt_path = None
    else:
        ckpt_path = args.resume_from_ckpt

    trainer.fit(
        model_module,
        datamodule=data_module,
        ckpt_path=ckpt_path,
    )


def bool_type(bool_str: str):
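    """Parse permissive boolean CLI values, e.g. 'yes'/'t'/'1' -> True."""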
    bool_str_lower = bool_str.lower()
    if bool_str_lower in ('false', 'f', 'no', 'n', '0'):
        return False
    elif bool_str_lower in ('true', 't', 'yes', 'y', '1'):
        return True
    else:
        raise ValueError(f'Cannot interpret {bool_str} as bool')


if __name__ == "__main__":
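    # Example invocation (paths and values are illustrative):
    #   python train_openfold.py mmcif_dir/ alignment_dir/ template_mmcif_dir/ \
    #       output_dir/ 2021-10-10 --config_preset initial_training \
    #       --precision bf16 --gpus 4 --seed 42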
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "train_data_dir", type=str,
        help="Directory containing training mmCIF files"
    )
    parser.add_argument(
        "train_alignment_dir", type=str,
        help="Directory containing precomputed training alignments"
    )
    parser.add_argument(
        "template_mmcif_dir", type=str,
        help="Directory containing mmCIF files to search for templates"
    )
    parser.add_argument(
        "output_dir", type=str,
        help='''Directory in which to output checkpoints, logs, etc. Ignored
                if not on rank 0'''
    )
    parser.add_argument(
        "max_template_date", type=str,
        help='''Cutoff for all templates. In training mode, templates are also
                filtered by the release date of the target'''
    )
    parser.add_argument(
        "--train_mmcif_data_cache_path", type=str, default=None,
        help="Path to a JSON file recording metadata of the mmCIF structures used during training"
    )
    parser.add_argument(
        "--use_single_seq_mode", type=bool_type, default=False,
        help="Use single sequence embeddings instead of MSAs."
    )
    parser.add_argument(
        "--distillation_data_dir", type=str, default=None,
        help="Directory containing training PDB files"
    )
    parser.add_argument(
        "--distillation_alignment_dir", type=str, default=None,
        help="Directory containing precomputed distillation alignments"
    )
    parser.add_argument(
        "--val_data_dir", type=str, default=None,
        help="Directory containing validation mmCIF files"
    )
    parser.add_argument(
        "--val_alignment_dir", type=str, default=None,
        help="Directory containing precomputed validation alignments"
    )
    parser.add_argument(
        "--val_mmcif_data_cache_path", type=str, default=None,
        help="Path to a JSON file recording metadata of the mmCIF structures used during validation"
    )
    parser.add_argument(
        "--kalign_binary_path", type=str, default='/usr/bin/kalign',
        help="Path to the kalign binary"
    )
    parser.add_argument(
        "--train_filter_path", type=str, default=None,
        help='''Optional path to a text file containing names of training
                examples to include, one per line. Used to filter the training
                set'''
    )
    parser.add_argument(
        "--distillation_filter_path", type=str, default=None,
        help="""See --train_filter_path"""
    )
    parser.add_argument(
        "--obsolete_pdbs_file_path", type=str, default=None,
        help="""Path to obsolete.dat file containing list of obsolete PDBs and
             their replacements."""
    )
    parser.add_argument(
        "--template_release_dates_cache_path", type=str, default=None,
        help="""Output of scripts/generate_mmcif_cache.py run on template mmCIF
                files."""
    )
    parser.add_argument(
        "--use_small_bfd", type=bool_type, default=False,
        help="Whether to use a reduced version of the BFD database"
    )
    parser.add_argument(
        "--seed", type=int, default=None,
        help="Random seed"
    )
    parser.add_argument(
        "--deepspeed_config_path", type=str, default=None,
        help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
    )
    parser.add_argument(
        "--checkpoint_every_epoch", action="store_true", default=False,
        help="""Whether to checkpoint at the end of every training epoch"""
    )
    parser.add_argument(
        "--early_stopping", type=bool_type, default=False,
        help="Whether to stop training when validation loss fails to decrease"
    )
    parser.add_argument(
        "--min_delta", type=float, default=0,
        help="""The smallest decrease in validation loss that counts as an
                improvement for the purposes of early stopping"""
    )
    parser.add_argument(
        "--patience", type=int, default=3,
        help="Early stopping patience"
    )
    parser.add_argument(
        "--resume_from_ckpt", type=str, default=None,
        help="Path to a model checkpoint from which to restore training state"
    )
    parser.add_argument(
        "--resume_model_weights_only", type=bool_type, default=False,
        help="Whether to load just model weights as opposed to training state"
    )
    parser.add_argument(
        "--resume_from_jax_params", type=str, default=None,
        help="""Path to an .npz JAX parameter file with which to initialize the model"""
    )
    parser.add_argument(
        "--log_performance", type=bool_type, default=False,
        help="Whether to measure and log training performance"
    )
    parser.add_argument(
        "--wandb", action="store_true", default=False,
        help="Whether to log metrics to Weights & Biases"
    )
    parser.add_argument(
        "--experiment_name", type=str, default=None,
        help="Name of the current experiment. Used for wandb logging"
    )
    parser.add_argument(
        "--wandb_id", type=str, default=None,
        help="ID of a previous run to be resumed"
    )
    parser.add_argument(
        "--wandb_project", type=str, default=None,
        help="Name of the wandb project to which this run will belong"
    )
    parser.add_argument(
        "--wandb_entity", type=str, default=None,
        help="wandb username or team name to which runs are attributed"
    )
    parser.add_argument(
        "--script_modules", type=bool_type, default=False,
        help="Whether to TorchScript eligible components of the model"
    )
    parser.add_argument(
        "--train_chain_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--distillation_chain_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--train_epoch_len", type=int, default=10000,
        help=(
            "The virtual length of each training epoch. Stochastic filtering "
            "of training data means that training datasets have no "
            "well-defined length. This virtual length affects frequency of "
            "validation & checkpointing (by default, one of each per epoch)."
        )
    )
    parser.add_argument(
        "--log_lr", action="store_true", default=False,
        help="Whether to log the actual learning rate"
    )
    parser.add_argument(
        "--config_preset", type=str, default="initial_training",
        help=(
            'Config setting. Choose e.g. "initial_training", "finetuning", '
            '"model_1", etc. By default, the actual values in the config are '
            'used.'
        )
    )
    parser.add_argument(
        "--_distillation_structure_index_path", type=str, default=None,
    )
    parser.add_argument(
        "--alignment_index_path", type=str, default=None,
        help="Training alignment index. See the README for instructions."
    )
    parser.add_argument(
        "--distillation_alignment_index_path", type=str, default=None,
        help="Distillation alignment index. See the README for instructions."
    )
    parser.add_argument(
        "--experiment_config_json", type=str, default="",
        help="Path to a JSON file with custom config values that override the config preset",
    )
    parser.add_argument(
        "--gpus", type=int, default=1,
        help="Number of GPUs; used to select the distributed strategy and the effective batch size."
    )

    trainer_group = parser.add_argument_group('PyTorch Lightning Trainer Args')
    trainer_group.add_argument(
        "--num_nodes", type=int, default=1,
    )
    trainer_group.add_argument(
        "--precision", type=str, default='bf16',
        help="Numerical precision (e.g. 'bf16', '32'); lower precision improves runtime performance."
    )
    trainer_group.add_argument(
        "--max_epochs", type=int, default=1,
    )
    trainer_group.add_argument(
        "--log_every_n_steps", type=int, default=25,
    )
    trainer_group.add_argument(
        "--flush_logs_every_n_steps", type=int, default=5,
    )
    trainer_group.add_argument(
        "--num_sanity_val_steps", type=int, default=0,
    )

    args = parser.parse_args()

    if (args.seed is None and
        ((args.gpus is not None and args.gpus > 1) or
         (args.num_nodes is not None and args.num_nodes > 1))):
        raise ValueError("For distributed training, --seed must be specified")

    if (str(args.precision) == "16" and args.deepspeed_config_path is not None):
        raise ValueError("DeepSpeed and FP16 training are not compatible")

    if (args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
        raise ValueError(
            "--resume_from_jax_params and --resume_from_ckpt are mutually exclusive")

    # This re-applies the training-time filters at the beginning of every epoch
    args.reload_dataloaders_every_n_epochs = 1

    main(args)