# train_openfold.py

import argparse
import logging
import os
import sys

import pytorch_lightning as pl
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DeepSpeedStrategy, DDPStrategy

import torch

from openfold.config import model_config
from openfold.data.data_modules import OpenFoldDataModule, OpenFoldMultimerDataModule
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants
from openfold.utils.argparse_utils import remove_arguments
from openfold.utils.callbacks import (
    EarlyStoppingVerbose,
)
from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.loss import AlphaFoldLoss, lddt_ca
from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
from openfold.utils.multi_chain_permutation import multi_chain_permutation_align
from openfold.utils.seed import seed_everything
from openfold.utils.superimposition import superimpose
from openfold.utils.tensor_utils import tensor_tree_map
from openfold.utils.validation_metrics import (
    drmsd,
    gdt_ts,
    gdt_ha,
)
from openfold.utils.import_weights import (
    import_jax_weights_,
    import_openfold_weights_
)
from scripts.zero_to_fp32 import (
    get_fp32_state_dict_from_zero_checkpoint,
    get_global_step_from_zero_checkpoint
)
from openfold.utils.logger import PerformanceLoggingCallback


class OpenFoldWrapper(pl.LightningModule):
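    """A PyTorch Lightning wrapper around the AlphaFold model.

    Bundles the model with its loss and an exponential moving average (EMA)
    of its weights; the EMA weights are the ones evaluated at validation
    time and stored alongside every checkpoint.
    """
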
    def __init__(self, config):
        super(OpenFoldWrapper, self).__init__()
        self.config = config
        self.model = AlphaFold(config)
        self.is_multimer = self.config.globals.is_multimer

        self.loss = AlphaFoldLoss(config.loss)

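        # EMA of the model weights; swapped in for validation and saved with
        # every checkpoint (see validation_step/on_save_checkpoint)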
        self.ema = ExponentialMovingAverage(
            model=self.model, decay=config.ema.decay
        )

        self.cached_weights = None
        self.last_lr_step = -1

    def forward(self, batch):
        return self.model(batch)

    def _log(self, loss_breakdown, batch, outputs, train=True):
        phase = "train" if train else "val"
        for loss_name, indiv_loss in loss_breakdown.items():
            self.log(
                f"{phase}/{loss_name}",
                indiv_loss,
                on_step=train, on_epoch=(not train), logger=True,
            )

            if (train):
                self.log(
                    f"{phase}/{loss_name}_epoch",
                    indiv_loss,
                    on_step=False, on_epoch=True, logger=True,
                )

        with torch.no_grad():
            other_metrics = self._compute_validation_metrics(
                batch,
                outputs,
                superimposition_metrics=(not train)
            )

        for k, v in other_metrics.items():
            self.log(
                f"{phase}/{k}",
                torch.mean(v),
                on_step=False, on_epoch=True, logger=True
            )

    def training_step(self, batch, batch_idx):
        if (self.ema.device != batch["aatype"].device):
            self.ema.to(batch["aatype"].device)

        ground_truth = batch.pop('gt_features', None)

        # Run the model
        outputs = self(batch)

        # Remove the recycling dimension; only the features fed to the final
        # recycling iteration are used for the loss
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

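        # For multimers, the assignment of predicted chains to ground-truth
        # chains is ambiguous; permute the ground truth to best match the
        # prediction before computing the loss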
        if self.is_multimer:
            batch = multi_chain_permutation_align(out=outputs,
                                                  features=batch,
                                                  ground_truth=ground_truth)

        # Compute loss
        loss, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
        )

        # Log it
        self._log(loss_breakdown, batch, outputs)

        return loss

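    # Lightning calls this hook after optimizer.step() and before the
    # gradients are zeroed, making it a natural place to update the EMA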
    def on_before_zero_grad(self, *args, **kwargs):
        self.ema.update(self.model)

    def validation_step(self, batch, batch_idx):
        # At the start of validation, load the EMA weights
        if (self.cached_weights is None):
            # model.state_dict() contains references to model weights rather
            # than copies. Therefore, we need to clone them before calling
            # load_state_dict().
            def clone_param(t): return t.detach().clone()
            self.cached_weights = tensor_tree_map(
                clone_param, self.model.state_dict())
            self.model.load_state_dict(self.ema.state_dict()["params"])

        ground_truth = batch.pop('gt_features', None)

        # Run the model
        outputs = self(batch)
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

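        # Validate against the unclamped FAPE loss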
        batch["use_clamped_fape"] = 0.

        if self.is_multimer:
            batch = multi_chain_permutation_align(out=outputs,
                                                  features=batch,
                                                  ground_truth=ground_truth)

        # Compute loss and other metrics
        _, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
        )

        self._log(loss_breakdown, batch, outputs, train=False)

    def on_validation_epoch_end(self):
        # Restore the model weights to normal
        self.model.load_state_dict(self.cached_weights)
        self.cached_weights = None

    def _compute_validation_metrics(self,
                                    batch,
                                    outputs,
                                    superimposition_metrics=False
                                    ):
        metrics = {}

        gt_coords = batch["all_atom_positions"]
        pred_coords = outputs["final_atom_positions"]
        all_atom_mask = batch["all_atom_mask"]

        # This is super janky for superimposition. Fix later
        gt_coords_masked = gt_coords * all_atom_mask[..., None]
        pred_coords_masked = pred_coords * all_atom_mask[..., None]
        ca_pos = residue_constants.atom_order["CA"]
        gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :]
        pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :]
        all_atom_mask_ca = all_atom_mask[..., ca_pos]

        lddt_ca_score = lddt_ca(
            pred_coords,
            gt_coords,
            all_atom_mask,
            eps=self.config.globals.eps,
            per_residue=False,
        )

        metrics["lddt_ca"] = lddt_ca_score

        drmsd_ca_score = drmsd(
            pred_coords_masked_ca,
            gt_coords_masked_ca,
            mask=all_atom_mask_ca,  # still required here to compute n
        )

        metrics["drmsd_ca"] = drmsd_ca_score

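        # lDDT and dRMSD are superposition-free; GDT-TS/GDT-HA are computed
        # after superimposing the prediction onto the ground truth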
        if (superimposition_metrics):
            superimposed_pred, alignment_rmsd = superimpose(
                gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
            )
            gdt_ts_score = gdt_ts(
                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
            )
            gdt_ha_score = gdt_ha(
                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
            )

            metrics["alignment_rmsd"] = alignment_rmsd
            metrics["gdt_ts"] = gdt_ts_score
            metrics["gdt_ha"] = gdt_ha_score

        return metrics

    def configure_optimizers(self,
                             learning_rate: float = 1e-3,
                             eps: float = 1e-5,
                             ) -> torch.optim.Adam:
        # Ignored as long as a DeepSpeed optimizer is configured
        optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=learning_rate,
            eps=eps
        )

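        # When resuming mid-run, PyTorch LR schedulers expect an 'initial_lr'
        # entry in each param group; it is normally set only when a scheduler
        # is constructed with last_epoch=-1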
        if self.last_lr_step != -1:
            for group in optimizer.param_groups:
                if 'initial_lr' not in group:
                    group['initial_lr'] = learning_rate

        lr_scheduler = AlphaFoldLRScheduler(
            optimizer,
            last_epoch=self.last_lr_step,  # resume the schedule mid-run
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
                "name": "AlphaFoldLRScheduler",
            }
        }

    def on_load_checkpoint(self, checkpoint):
        ema = checkpoint["ema"]
        if (not self.model.template_config.enabled):
            # Drop template parameters from the checkpointed EMA if the
            # current model runs with templates disabled
            ema["params"] = {
                k: v for k, v in ema["params"].items() if "template" not in k
            }
        self.ema.load_state_dict(ema)

    def on_save_checkpoint(self, checkpoint):
        checkpoint["ema"] = self.ema.state_dict()

    def resume_last_lr_step(self, lr_step):
        self.last_lr_step = lr_step

    def load_from_jax(self, jax_path):
        model_basename = os.path.splitext(
            os.path.basename(
                os.path.normpath(jax_path)
            )
        )[0]
        model_version = "_".join(model_basename.split("_")[1:])
        import_jax_weights_(
            self.model, jax_path, version=model_version
        )


def main(args):
    if (args.seed is not None):
        seed_everything(args.seed)

    config = model_config(
        args.config_preset,
        train=True,
        low_prec=(str(args.precision) == "16")
    )
    model_module = OpenFoldWrapper(config)

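    # A checkpoint directory indicates a DeepSpeed ZeRO checkpoint; a single
    # file is an ordinary Lightning checkpoint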
    if (args.resume_from_ckpt):
        if (os.path.isdir(args.resume_from_ckpt)):
            last_global_step = get_global_step_from_zero_checkpoint(
                args.resume_from_ckpt)
        else:
            sd = torch.load(args.resume_from_ckpt)
            last_global_step = int(sd['global_step'])
        model_module.resume_last_lr_step(last_global_step)
        logging.info("Successfully loaded last lr step...")

    if (args.resume_from_ckpt and args.resume_model_weights_only):
        if (os.path.isdir(args.resume_from_ckpt)):
            sd = get_fp32_state_dict_from_zero_checkpoint(
                args.resume_from_ckpt)
        else:
            sd = torch.load(args.resume_from_ckpt)
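        # Checkpoints store the wrapped module, so strip the "module." prefix
        # before loading the weights into the bare model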
        sd = {k[len("module."):]: v for k, v in sd.items()}
        import_openfold_weights_(model=model_module, state_dict=sd)
        logging.info("Successfully loaded model weights...")
    if (args.resume_from_jax_params):
        model_module.load_from_jax(args.resume_from_jax_params)
        logging.info(
            f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")

    # TorchScript components of the model
    if (args.script_modules):
        script_preset_(model_module)

    if "multimer" in args.config_preset:
        data_module = OpenFoldMultimerDataModule(
            config=config.data,
            batch_seed=args.seed,
            **vars(args)
        )
    else:
        data_module = OpenFoldDataModule(
            config=config.data,
            batch_seed=args.seed,
            **vars(args)
        )

    data_module.prepare_data()
    data_module.setup()

    callbacks = []
    if (args.checkpoint_every_epoch):
        mc = ModelCheckpoint(
            every_n_epochs=1,
            auto_insert_metric_name=False,
            save_top_k=-1,
        )
        callbacks.append(mc)

    if (args.early_stopping):
        es = EarlyStoppingVerbose(
            monitor="val/lddt_ca",
            min_delta=args.min_delta,
            patience=args.patience,
            verbose=False,
            mode="max",
            check_finite=True,
            strict=True,
        )
        callbacks.append(es)

    if (args.log_performance):
        global_batch_size = args.num_nodes * args.gpus
        perf = PerformanceLoggingCallback(
            log_file=os.path.join(args.output_dir, "performance_log.json"),
            global_batch_size=global_batch_size,
        )
        callbacks.append(perf)

    if (args.log_lr):
        lr_monitor = LearningRateMonitor(logging_interval="step")
        callbacks.append(lr_monitor)

    loggers = []
    if (args.wandb):
        wdb_logger = WandbLogger(
            name=args.experiment_name,
            save_dir=args.output_dir,
            id=args.wandb_id,
            project=args.wandb_project,
            entity=args.wandb_entity,
        )
        loggers.append(wdb_logger)

    if (args.deepspeed_config_path is not None):
        strategy = DeepSpeedStrategy(
            config=args.deepspeed_config_path,
        )
        if (args.wandb):
            wdb_logger.experiment.save(args.deepspeed_config_path)
            wdb_logger.experiment.save("openfold/config.py")
    elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1:
        strategy = DDPStrategy(find_unused_parameters=False)
    else:
        strategy = None

    if (args.wandb):
        # Record the environment's package versions alongside the run
        freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt"
        os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
        wdb_logger.experiment.save(freeze_path)

    # Raw dump of all the args accepted by the pl.Trainer constructor. Newer
    # PyTorch Lightning releases dropped Trainer.add_argparse_args, so the
    # argparse namespace is filtered down by hand instead.
    trainer_kws = set([
        'accelerator', 'strategy', 'devices', 'num_nodes', 'precision',
        'logger', 'callbacks', 'fast_dev_run', 'max_epochs', 'min_epochs',
        'max_steps', 'min_steps', 'max_time', 'limit_train_batches',
        'limit_val_batches', 'limit_test_batches', 'limit_predict_batches',
        'overfit_batches', 'val_check_interval', 'check_val_every_n_epoch',
        'num_sanity_val_steps', 'log_every_n_steps', 'enable_checkpointing',
        'enable_progress_bar', 'enable_model_summary',
        'accumulate_grad_batches', 'gradient_clip_val',
        'gradient_clip_algorithm', 'deterministic', 'benchmark',
        'inference_mode', 'use_distributed_sampler', 'profiler',
        'detect_anomaly', 'barebones', 'plugins', 'sync_batchnorm',
        'reload_dataloaders_every_n_epochs', 'default_root_dir',
    ])
    trainer_args = {k: v for k, v in vars(args).items() if k in trainer_kws}
    trainer_args.update({
        'default_root_dir': args.output_dir,
        'strategy': strategy,
        'callbacks': callbacks,
        'logger': loggers,
    })
    trainer = pl.Trainer(**trainer_args)

    if (args.resume_model_weights_only):
        ckpt_path = None
    else:
        ckpt_path = args.resume_from_ckpt

    trainer.fit(
        model_module,
        datamodule=data_module,
        ckpt_path=ckpt_path,
    )


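# bool_type lets boolean arguments take explicit values on the command line,
# e.g. "--early_stopping true" or "--resume_model_weights_only 0".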
def bool_type(bool_str: str):
    bool_str_lower = bool_str.lower()
    if bool_str_lower in ('false', 'f', 'no', 'n', '0'):
        return False
    elif bool_str_lower in ('true', 't', 'yes', 'y', '1'):
        return True
    else:
        raise ValueError(f'Cannot interpret {bool_str} as bool')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "train_data_dir", type=str,
        help="Directory containing training mmCIF files"
    )
    parser.add_argument(
        "train_alignment_dir", type=str,
        help="Directory containing precomputed training alignments"
    )
    parser.add_argument(
        "template_mmcif_dir", type=str,
        help="Directory containing mmCIF files to search for templates"
    )
    parser.add_argument(
        "output_dir", type=str,
        help='''Directory in which to output checkpoints, logs, etc. Ignored
                if not on rank 0'''
    )
    parser.add_argument(
        "max_template_date", type=str,
        help='''Cutoff for all templates. In training mode, templates are also 
                filtered by the release date of the target'''
    )
    parser.add_argument(
        "--train_mmcif_data_cache_path", type=str, default=None,
        help="""Path to a JSON file recording information about the mmCIF
                structures used during training"""
    )
    parser.add_argument(
        "--use_single_seq_mode", type=bool_type, default=False,
        help="Use single sequence embeddings instead of MSAs."
    )
    parser.add_argument(
        "--distillation_data_dir", type=str, default=None,
        help="Directory containing training PDB files"
    )
    parser.add_argument(
        "--distillation_alignment_dir", type=str, default=None,
        help="Directory containing precomputed distillation alignments"
    )
    parser.add_argument(
        "--val_data_dir", type=str, default=None,
        help="Directory containing validation mmCIF files"
    )
    parser.add_argument(
        "--val_alignment_dir", type=str, default=None,
        help="Directory containing precomputed validation alignments"
    )
    parser.add_argument(
        "--val_mmcif_data_cache_path", type=str, default=None,
        help="""Path to a JSON file recording information about the mmCIF
                structures used during validation"""
    )
    parser.add_argument(
        "--kalign_binary_path", type=str, default='/usr/bin/kalign',
        help="Path to the kalign binary"
    )
    parser.add_argument(
        "--train_filter_path", type=str, default=None,
        help='''Optional path to a text file containing names of training
                examples to include, one per line. Used to filter the training 
                set'''
    )
    parser.add_argument(
        "--distillation_filter_path", type=str, default=None,
        help="""See --train_filter_path"""
    )
    parser.add_argument(
        "--obsolete_pdbs_file_path", type=str, default=None,
        help="""Path to obsolete.dat file containing list of obsolete PDBs and 
             their replacements."""
    )
    parser.add_argument(
        "--template_release_dates_cache_path", type=str, default=None,
        help="""Output of scripts/generate_mmcif_cache.py run on template mmCIF
                files."""
    )
    parser.add_argument(
        "--use_small_bfd", type=bool_type, default=False,
        help="Whether to use a reduced version of the BFD database"
    )
    parser.add_argument(
        "--seed", type=int, default=None,
        help="Random seed"
    )
    parser.add_argument(
        "--deepspeed_config_path", type=str, default=None,
        help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
    )
    parser.add_argument(
        "--checkpoint_every_epoch", action="store_true", default=False,
        help="""Whether to checkpoint at the end of every training epoch"""
    )
    parser.add_argument(
        "--early_stopping", type=bool_type, default=False,
        help="Whether to stop training when validation loss fails to decrease"
    )
    parser.add_argument(
        "--min_delta", type=float, default=0,
        help="""The smallest decrease in validation loss that counts as an 
                improvement for the purposes of early stopping"""
    )
    parser.add_argument(
        "--patience", type=int, default=3,
        help="Early stopping patience"
    )
    parser.add_argument(
        "--resume_from_ckpt", type=str, default=None,
        help="Path to a model checkpoint from which to restore training state"
    )
    parser.add_argument(
        "--resume_model_weights_only", type=bool_type, default=False,
        help="Whether to load just model weights as opposed to training state"
    )
    parser.add_argument(
        "--resume_from_jax_params", type=str, default=None,
        help="""Path to an .npz JAX parameter file with which to initialize the model"""
    )
    parser.add_argument(
        "--log_performance", type=bool_type, default=False,
        help="Measure performance"
    )
    parser.add_argument(
        "--wandb", action="store_true", default=False,
        help="Whether to log metrics to Weights & Biases"
    )
    parser.add_argument(
        "--experiment_name", type=str, default=None,
        help="Name of the current experiment. Used for wandb logging"
    )
    parser.add_argument(
        "--wandb_id", type=str, default=None,
        help="ID of a previous run to be resumed"
    )
    parser.add_argument(
        "--wandb_project", type=str, default=None,
        help="Name of the wandb project to which this run will belong"
    )
    parser.add_argument(
        "--wandb_entity", type=str, default=None,
        help="wandb username or team name to which runs are attributed"
    )
    parser.add_argument(
        "--script_modules", type=bool_type, default=False,
        help="Whether to TorchScript eligible components of the model"
    )
    parser.add_argument(
        "--train_chain_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--distillation_chain_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--train_epoch_len", type=int, default=10000,
        help=(
            "The virtual length of each training epoch. Stochastic filtering "
            "of training data means that training datasets have no "
            "well-defined length. This virtual length affects frequency of "
            "validation & checkpointing (by default, one of each per epoch)."
        )
    )
    parser.add_argument(
        "--log_lr", action="store_true", default=False,
        help="Whether to log the actual learning rate"
    )
    parser.add_argument(
        "--config_preset", type=str, default="initial_training",
        help=(
            'Config setting. Choose e.g. "initial_training", "finetuning", '
            '"model_1", etc. By default, the actual values in the config are '
            'used.'
        )
    )
    parser.add_argument(
        "--_distillation_structure_index_path", type=str, default=None,
    )
    parser.add_argument(
        "--alignment_index_path", type=str, default=None,
        help="Training alignment index. See the README for instructions."
    )
    parser.add_argument(
        "--distillation_alignment_index_path", type=str, default=None,
        help="Distillation alignment index. See the README for instructions."
    )
    parser.add_argument(
        "--num_nodes", type=int, default=1,
    )
    parser.add_argument(
        "--gpus", type=int, default=1,
    )
    parser.add_argument(
        "--precision", type=str, default=None,
    )
    parser.add_argument(
        "--replace_sampler_ddp", type=bool_type, default=True,
    )
    parser.add_argument(
        "--max_epochs", type=int, default=1,
    )
    parser.add_argument(
        "--log_every_n_steps", type=int, default=25,
    )
    parser.add_argument(
        "--num_sanity_val_steps", type=int, default=0,
    )

    #  parser = pl.Trainer.add_argparse_args(parser)
    #
    #  # Disable the initial validation pass
    #  parser.set_defaults(
    #      num_sanity_val_steps=0,
    #  )

    #  # Remove some buggy/redundant arguments introduced by the Trainer
    #  remove_arguments(
    #      parser,
    #      [
    #          "--accelerator",
    #          "--resume_from_checkpoint",
    #          "--reload_dataloaders_every_epoch",
    #          "--reload_dataloaders_every_n_epochs",
    #      ]
    #  )

    args = parser.parse_args()

    if (args.seed is None and
        ((args.gpus is not None and args.gpus > 1) or
         (args.num_nodes is not None and args.num_nodes > 1))):
        raise ValueError("For distributed training, --seed must be specified")

    if (str(args.precision) == "16" and args.deepspeed_config_path is not None):
        raise ValueError("DeepSpeed and FP16 training are not compatible")

    if (args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
        raise ValueError(
            "Choose between loading pretrained JAX weights and resuming from "
            "a checkpoint; the two options are mutually exclusive"
        )

    # This re-applies the training-time filters at the beginning of every epoch
    args.reload_dataloaders_every_n_epochs = 1

    main(args)
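
# Example invocation (all paths and values below are placeholders; see the
# OpenFold README for how the data directories are laid out):
#
#   python3 train_openfold.py \
#       data/mmcifs/ data/alignments/ data/template_mmcifs/ output/ \
#       2021-10-10 \
#       --config_preset initial_training \
#       --seed 42 \
#       --gpus 2 \
#       --precision 16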