train_openfold.py 23.1 KB
Newer Older
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
1
2
3
import argparse
import logging
import os
4
import sys
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
5
6

import pytorch_lightning as pl
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
7
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
8
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
9
from pytorch_lightning.loggers import WandbLogger
10
from pytorch_lightning.strategies import DeepSpeedStrategy, DDPStrategy
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
11
12
13
import torch

from openfold.config import model_config
14
from openfold.data.data_modules import OpenFoldDataModule, OpenFoldMultimerDataModule
15
from openfold.model.model import AlphaFold
16
from openfold.model.torchscript import script_preset_
17
from openfold.np import residue_constants
18
from openfold.utils.argparse_utils import remove_arguments
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
19
20
21
from openfold.utils.callbacks import (
    EarlyStoppingVerbose,
)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
22
from openfold.utils.exponential_moving_average import ExponentialMovingAverage
23
from openfold.utils.loss import AlphaFoldLoss, lddt_ca
24
from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
25
from openfold.utils.multi_chain_permutation import multi_chain_permutation_align
26
from openfold.utils.seed import seed_everything
27
from openfold.utils.superimposition import superimpose
28
from openfold.utils.tensor_utils import tensor_tree_map
29
30
31
32
33
from openfold.utils.validation_metrics import (
    drmsd,
    gdt_ts,
    gdt_ha,
)
34
35
from openfold.utils.import_weights import (
    import_jax_weights_,
36
    import_openfold_weights_
37
)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
38
from scripts.zero_to_fp32 import (
39
40
    get_fp32_state_dict_from_zero_checkpoint,
    get_global_step_from_zero_checkpoint
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
41
)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
42

Marta's avatar
Marta committed
43
44
from openfold.utils.logger import PerformanceLoggingCallback

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
45
46
47
48
49

class OpenFoldWrapper(pl.LightningModule):
    """LightningModule bundling the AlphaFold model, its loss and an EMA.

    Training updates ``self.model`` directly and keeps an exponential moving
    average (EMA) of its weights in ``self.ema``. During validation the EMA
    weights are temporarily loaded into the model; the training weights are
    cached in ``self.cached_weights`` and restored when validation ends.
    """

    def __init__(self, config):
        """Build the model, loss and EMA tracker from ``config``.

        Args:
            config: Model configuration object. This class reads
                ``config.globals.is_multimer``, ``config.globals.eps``,
                ``config.loss`` and ``config.ema.decay`` from it and passes
                the whole object to ``AlphaFold``.
        """
        super().__init__()
        self.config = config
        self.model = AlphaFold(config)
        self.is_multimer = self.config.globals.is_multimer

        self.loss = AlphaFoldLoss(config.loss)

        self.ema = ExponentialMovingAverage(
            model=self.model, decay=config.ema.decay
        )

        # Training weights stashed while the EMA weights are loaded for
        # validation; None whenever the model holds its training weights.
        self.cached_weights = None
        # Scheduler step to resume from; -1 means "start a fresh schedule".
        self.last_lr_step = -1

    def forward(self, batch):
        """Run the wrapped AlphaFold model on ``batch`` and return its outputs."""
        return self.model(batch)

    def _log(self, loss_breakdown, batch, outputs, train=True):
        """Log every loss term plus the structural metrics for one step.

        Args:
            loss_breakdown: Mapping from loss name to its value.
            batch: Feature dict passed on to ``_compute_validation_metrics``.
            outputs: Model outputs passed on to ``_compute_validation_metrics``.
            train: True to log under the ``train/`` prefix (per step, plus an
                extra ``*_epoch`` aggregate); False to log under ``val/``
                (per epoch only, with superimposition metrics enabled).
        """
        phase = "train" if train else "val"
        for loss_name, indiv_loss in loss_breakdown.items():
            self.log(
                f"{phase}/{loss_name}",
                indiv_loss,
                on_step=train, on_epoch=(not train), logger=True,
            )

            if train:
                # Training losses are additionally aggregated per epoch.
                self.log(
                    f"{phase}/{loss_name}_epoch",
                    indiv_loss,
                    on_step=False, on_epoch=True, logger=True,
                )

        with torch.no_grad():
            other_metrics = self._compute_validation_metrics(
                batch,
                outputs,
                superimposition_metrics=(not train)
            )

        for k, v in other_metrics.items():
            self.log(
                f"{phase}/{k}",
                torch.mean(v),
                on_step=False, on_epoch=True, logger=True
            )

    def training_step(self, batch, batch_idx):
        """Run one training step and return the scalar loss.

        Args:
            batch: Feature dict. An optional ``"gt_features"`` entry is
                removed before the forward pass and only used for multimer
                chain permutation alignment.
            batch_idx: Index of the batch (unused).

        Returns:
            The total loss produced by ``self.loss``.
        """
        # Keep the EMA copy on the same device as the incoming data.
        if self.ema.device != batch["aatype"].device:
            self.ema.to(batch["aatype"].device)

        ground_truth = batch.pop('gt_features', None)

        # Run the model
        outputs = self(batch)

        # Remove the recycling dimension
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

        if self.is_multimer:
            batch = multi_chain_permutation_align(out=outputs,
                                                  features=batch,
                                                  ground_truth=ground_truth)

        # Compute loss
        loss, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
        )

        # Log it
        self._log(loss_breakdown, batch, outputs)

        return loss

    def on_before_zero_grad(self, *args, **kwargs):
        """Lightning hook: fold the current model weights into the EMA."""
        self.ema.update(self.model)

    def validation_step(self, batch, batch_idx):
        """Run one validation step using the EMA weights and log metrics.

        On the first call of a validation run the training weights are
        cloned into ``self.cached_weights`` and replaced by the EMA weights;
        ``on_validation_epoch_end`` undoes this.

        Args:
            batch: Feature dict; see ``training_step``.
            batch_idx: Index of the batch (unused).
        """
        # At the start of validation, load the EMA weights
        if self.cached_weights is None:
            # model.state_dict() contains references to model weights rather
            # than copies. Therefore, we need to clone them before calling
            # load_state_dict().
            def clone_param(t):
                return t.detach().clone()

            self.cached_weights = tensor_tree_map(
                clone_param, self.model.state_dict())
            self.model.load_state_dict(self.ema.state_dict()["params"])

        ground_truth = batch.pop('gt_features', None)

        # Run the model
        outputs = self(batch)
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

        # NOTE(review): presumably switches the loss to unclamped FAPE for
        # validation -- confirm against AlphaFoldLoss.
        batch["use_clamped_fape"] = 0.

        if self.is_multimer:
            batch = multi_chain_permutation_align(out=outputs,
                                                  features=batch,
                                                  ground_truth=ground_truth)

        # Compute loss and other metrics
        _, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
        )

        self._log(loss_breakdown, batch, outputs, train=False)

    def on_validation_epoch_end(self, _=None):
        """Lightning hook: restore the training weights after validation.

        Args:
            _: Ignored. Kept (with a default) so the hook works both when it
                is invoked without arguments, as Lightning 2.x does, and when
                a caller still passes the collected step outputs.
        """
        if self.cached_weights is None:
            # No validation step ran, so there is nothing to restore.
            return

        # Restore the model weights to normal
        self.model.load_state_dict(self.cached_weights)
        self.cached_weights = None

    def _compute_validation_metrics(self,
                                    batch,
                                    outputs,
                                    superimposition_metrics=False
                                    ):
        """Compute structural quality metrics for a prediction.

        Args:
            batch: Feature dict providing ``"all_atom_positions"`` and
                ``"all_atom_mask"``.
            outputs: Model outputs providing ``"final_atom_positions"``.
            superimposition_metrics: If True, also superimpose the predicted
                C-alpha trace onto the ground truth and report
                ``"alignment_rmsd"``, ``"gdt_ts"`` and ``"gdt_ha"``.

        Returns:
            Dict with ``"lddt_ca"`` and ``"drmsd_ca"``, plus the
            superimposition metrics when requested.
        """
        metrics = {}

        gt_coords = batch["all_atom_positions"]
        pred_coords = outputs["final_atom_positions"]
        all_atom_mask = batch["all_atom_mask"]

        # TODO: zeroing masked atoms is a crude way to prepare inputs for
        # superimposition; replace with proper masking.
        gt_coords_masked = gt_coords * all_atom_mask[..., None]
        pred_coords_masked = pred_coords * all_atom_mask[..., None]
        ca_pos = residue_constants.atom_order["CA"]
        gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :]
        pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :]
        all_atom_mask_ca = all_atom_mask[..., ca_pos]

        lddt_ca_score = lddt_ca(
            pred_coords,
            gt_coords,
            all_atom_mask,
            eps=self.config.globals.eps,
            per_residue=False,
        )

        metrics["lddt_ca"] = lddt_ca_score

        drmsd_ca_score = drmsd(
            pred_coords_masked_ca,
            gt_coords_masked_ca,
            mask=all_atom_mask_ca,  # still required here to compute n
        )

        metrics["drmsd_ca"] = drmsd_ca_score

        if superimposition_metrics:
            superimposed_pred, alignment_rmsd = superimpose(
                gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
            )
            gdt_ts_score = gdt_ts(
                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
            )
            gdt_ha_score = gdt_ha(
                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
            )

            metrics["alignment_rmsd"] = alignment_rmsd
            metrics["gdt_ts"] = gdt_ts_score
            metrics["gdt_ha"] = gdt_ha_score

        return metrics

    def configure_optimizers(self,
                             learning_rate: float = 1e-3,
                             eps: float = 1e-5,
                             ) -> dict:
        """Create the Adam optimizer and the AlphaFold LR scheduler.

        Args:
            learning_rate: Learning rate handed to Adam, and used as
                ``initial_lr`` when resuming the schedule.
            eps: Adam epsilon.

        Returns:
            A Lightning optimizer configuration dict with the optimizer and a
            per-step ``AlphaFoldLRScheduler``.
        """
        # Ignored as long as a DeepSpeed optimizer is configured
        optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=learning_rate,
            eps=eps
        )

        if self.last_lr_step != -1:
            # A scheduler constructed with last_epoch != -1 expects every
            # param group to already carry 'initial_lr'.
            for group in optimizer.param_groups:
                if 'initial_lr' not in group:
                    group['initial_lr'] = learning_rate

        lr_scheduler = AlphaFoldLRScheduler(
            optimizer,
            last_epoch=self.last_lr_step
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
                "name": "AlphaFoldLRScheduler",
            }
        }

    def on_load_checkpoint(self, checkpoint):
        """Lightning hook: restore the EMA state stored under ``"ema"``.

        When templates are disabled in the model, every EMA parameter whose
        name contains ``"template"`` is dropped before loading.
        """
        ema = checkpoint["ema"]
        if not self.model.template_config.enabled:
            ema["params"] = {
                k: v for k, v in ema["params"].items() if "template" not in k
            }
        self.ema.load_state_dict(ema)

    def on_save_checkpoint(self, checkpoint):
        """Lightning hook: store the EMA state in the checkpoint under ``"ema"``."""
        checkpoint["ema"] = self.ema.state_dict()

    def resume_last_lr_step(self, lr_step):
        """Record the scheduler step from which ``configure_optimizers`` resumes."""
        self.last_lr_step = lr_step

    def load_from_jax(self, jax_path):
        """Initialize the model from a JAX parameter file.

        The model version is derived from the file name by dropping the
        extension and the first ``_``-separated token (a basename of
        ``params_model_1`` yields version ``model_1``).

        Args:
            jax_path: Path to the JAX parameter file.
        """
        model_basename = os.path.splitext(
            os.path.basename(
                os.path.normpath(jax_path)
            )
        )[0]
        model_version = "_".join(model_basename.split("_")[1:])
        import_jax_weights_(
            self.model, jax_path, version=model_version
        )

274

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
275
def main(args):
    """Assemble model, data, callbacks, loggers and trainer, then train.

    Args:
        args: Parsed command-line namespace. The whole namespace is also
            forwarded as keyword arguments to the data module, and the subset
            of its attributes named in ``trainer_kws`` is forwarded to
            ``pl.Trainer``.
    """
    if args.seed is not None:
        seed_everything(args.seed)

    config = model_config(
        args.config_preset,
        train=True,
        low_prec=(str(args.precision) == "16")
    )
    model_module = OpenFoldWrapper(config)

    if args.resume_from_ckpt:
        # A directory is treated as a DeepSpeed ZeRO checkpoint, a file as a
        # regular checkpoint loadable with torch.load.
        if os.path.isdir(args.resume_from_ckpt):
            last_global_step = get_global_step_from_zero_checkpoint(
                args.resume_from_ckpt)
        else:
            sd = torch.load(args.resume_from_ckpt)
            last_global_step = int(sd['global_step'])
        model_module.resume_last_lr_step(last_global_step)
        logging.info("Successfully loaded last lr step...")
    if args.resume_from_ckpt and args.resume_model_weights_only:
        if os.path.isdir(args.resume_from_ckpt):
            sd = get_fp32_state_dict_from_zero_checkpoint(
                args.resume_from_ckpt)
        else:
            sd = torch.load(args.resume_from_ckpt)
        # Strip the "module." prefix only where it is actually present, so
        # keys without it are passed through instead of being truncated.
        prefix = "module."
        sd = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in sd.items()
        }
        import_openfold_weights_(model=model_module, state_dict=sd)
        logging.info("Successfully loaded model weights...")
    if args.resume_from_jax_params:
        model_module.load_from_jax(args.resume_from_jax_params)
        logging.info(
            f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")

    # TorchScript components of the model
    if args.script_modules:
        script_preset_(model_module)

    if "multimer" in args.config_preset:
        data_module = OpenFoldMultimerDataModule(
            config=config.data,
            batch_seed=args.seed,
            **vars(args)
        )
    else:
        data_module = OpenFoldDataModule(
            config=config.data,
            batch_seed=args.seed,
            **vars(args)
        )

    data_module.prepare_data()
    data_module.setup()

    callbacks = []
    if args.checkpoint_every_epoch:
        mc = ModelCheckpoint(
            every_n_epochs=1,
            auto_insert_metric_name=False,
            save_top_k=-1,
        )
        callbacks.append(mc)

    if args.early_stopping:
        es = EarlyStoppingVerbose(
            monitor="val/lddt_ca",
            min_delta=args.min_delta,
            patience=args.patience,
            verbose=False,
            mode="max",
            check_finite=True,
            strict=True,
        )
        callbacks.append(es)

    if args.log_performance:
        global_batch_size = args.num_nodes * args.gpus
        perf = PerformanceLoggingCallback(
            log_file=os.path.join(args.output_dir, "performance_log.json"),
            global_batch_size=global_batch_size,
        )
        callbacks.append(perf)

    if args.log_lr:
        lr_monitor = LearningRateMonitor(logging_interval="step")
        callbacks.append(lr_monitor)

    loggers = []
    if args.wandb:
        wdb_logger = WandbLogger(
            name=args.experiment_name,
            save_dir=args.output_dir,
            id=args.wandb_id,
            project=args.wandb_project,
            **{"entity": args.wandb_entity}
        )
        loggers.append(wdb_logger)

    if args.deepspeed_config_path is not None:
        strategy = DeepSpeedStrategy(
            config=args.deepspeed_config_path,
        )
        if args.wandb:
            wdb_logger.experiment.save(args.deepspeed_config_path)
            wdb_logger.experiment.save("openfold/config.py")
    elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1:
        strategy = DDPStrategy(find_unused_parameters=False)
    else:
        strategy = None

    if args.wandb:
        # Best effort: record the installed package versions with the run.
        freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt"
        os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
        wdb_logger.experiment.save(f"{freeze_path}")

    # Keyword arguments accepted by the pl.Trainer constructor; only these
    # attributes of ``args`` are forwarded to the Trainer.
    trainer_kws = {
        'accelerator', 'strategy', 'devices', 'num_nodes', 'precision',
        'logger', 'callbacks', 'fast_dev_run', 'max_epochs', 'min_epochs',
        'max_steps', 'min_steps', 'max_time', 'limit_train_batches',
        'limit_val_batches', 'limit_test_batches', 'limit_predict_batches',
        'overfit_batches', 'val_check_interval', 'check_val_every_n_epoch',
        'num_sanity_val_steps', 'log_every_n_steps', 'enable_checkpointing',
        'enable_progress_bar', 'enable_model_summary',
        'accumulate_grad_batches', 'gradient_clip_val',
        'gradient_clip_algorithm', 'deterministic', 'benchmark',
        'inference_mode', 'use_distributed_sampler', 'profiler',
        'detect_anomaly', 'barebones', 'plugins', 'sync_batchnorm',
        'reload_dataloaders_every_n_epochs', 'default_root_dir',
    }
    trainer_args = {k: v for k, v in vars(args).items() if k in trainer_kws}
    trainer_args.update({
        'default_root_dir': args.output_dir,
        'strategy': strategy,
        'callbacks': callbacks,
        'logger': loggers,
    })
    trainer = pl.Trainer(**trainer_args)

    # When only the weights were restored above, start a fresh training
    # state instead of letting the Trainer reload the checkpoint.
    if args.resume_model_weights_only:
        ckpt_path = None
    else:
        ckpt_path = args.resume_from_ckpt

    trainer.fit(
        model_module,
        datamodule=data_module,
        ckpt_path=ckpt_path,
    )


Marta's avatar
Marta committed
415
416
417
418
419
420
421
422
423
424
def bool_type(bool_str: str):
    """Convert a command-line string to a bool (case-insensitive).

    Accepts ``false/f/no/n/0`` and ``true/t/yes/y/1``.

    Raises:
        ValueError: If ``bool_str`` is none of the recognized spellings.
    """
    falsy = {'false', 'f', 'no', 'n', '0'}
    truthy = {'true', 't', 'yes', 'y', '1'}

    normalized = bool_str.lower()
    if normalized in falsy:
        return False
    if normalized in truthy:
        return True
    raise ValueError(f'Cannot interpret {bool_str} as bool')


Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
425
426
427
428
429
430
431
432
433
434
435
436
437
438
# Command-line entry point: declare the arguments, validate the combination
# that was passed, and hand the parsed namespace to main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "train_data_dir", type=str,
        help="Directory containing training mmCIF files"
    )
    parser.add_argument(
        "train_alignment_dir", type=str,
        help="Directory containing precomputed training alignments"
    )
    parser.add_argument(
        "template_mmcif_dir", type=str,
        help="Directory containing mmCIF files to search for templates"
    )
    parser.add_argument(
        "output_dir", type=str,
        help='''Directory in which to output checkpoints, logs, etc. Ignored
                if not on rank 0'''
    )
    parser.add_argument(
        "max_template_date", type=str,
        help='''Cutoff for all templates. In training mode, templates are also
                filtered by the release date of the target'''
    )
    parser.add_argument(
        "--train_mmcif_data_cache_path", type=str, default=None,
        help="Path to the json file which records all the information of mmcif structures used during training"
    )
    parser.add_argument(
        # Parsed with bool_type so that e.g. "False" is not a truthy string.
        "--use_single_seq_mode", type=bool_type, default=False,
        help="Use single sequence embeddings instead of MSAs."
    )
    parser.add_argument(
        "--distillation_data_dir", type=str, default=None,
        help="Directory containing training PDB files"
    )
    parser.add_argument(
        "--distillation_alignment_dir", type=str, default=None,
        help="Directory containing precomputed distillation alignments"
    )
    parser.add_argument(
        "--val_data_dir", type=str, default=None,
        help="Directory containing validation mmCIF files"
    )
    parser.add_argument(
        "--val_alignment_dir", type=str, default=None,
        help="Directory containing precomputed validation alignments"
    )
    parser.add_argument(
        "--val_mmcif_data_cache_path", type=str, default=None,
        help="path to the json file which records all the information of mmcif structures used during validation"
    )
    parser.add_argument(
        "--kalign_binary_path", type=str, default='/usr/bin/kalign',
        help="Path to the kalign binary"
    )
    parser.add_argument(
        "--train_filter_path", type=str, default=None,
        help='''Optional path to a text file containing names of training
                examples to include, one per line. Used to filter the training
                set'''
    )
    parser.add_argument(
        "--distillation_filter_path", type=str, default=None,
        help="""See --train_filter_path"""
    )
    parser.add_argument(
        "--obsolete_pdbs_file_path", type=str, default=None,
        help="""Path to obsolete.dat file containing list of obsolete PDBs and
             their replacements."""
    )
    parser.add_argument(
        "--template_release_dates_cache_path", type=str, default=None,
        help="""Output of scripts/generate_mmcif_cache.py run on template mmCIF
                files."""
    )
    parser.add_argument(
        "--use_small_bfd", type=bool_type, default=False,
        help="Whether to use a reduced version of the BFD database"
    )
    parser.add_argument(
        "--seed", type=int, default=None,
        help="Random seed"
    )
    parser.add_argument(
        "--deepspeed_config_path", type=str, default=None,
        help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
    )
    parser.add_argument(
        "--checkpoint_every_epoch", action="store_true", default=False,
        help="""Whether to checkpoint at the end of every training epoch"""
    )
    parser.add_argument(
        "--early_stopping", type=bool_type, default=False,
        help="Whether to stop training when validation lDDT-Ca (val/lddt_ca) fails to improve"
    )
    parser.add_argument(
        "--min_delta", type=float, default=0,
        help="""The smallest increase in validation lDDT-Ca that counts as an
                improvement for the purposes of early stopping"""
    )
    parser.add_argument(
        "--patience", type=int, default=3,
        help="Early stopping patience"
    )
    parser.add_argument(
        "--resume_from_ckpt", type=str, default=None,
        help="Path to a model checkpoint from which to restore training state"
    )
    parser.add_argument(
        "--resume_model_weights_only", type=bool_type, default=False,
        help="Whether to load just model weights as opposed to training state"
    )
    parser.add_argument(
        "--resume_from_jax_params", type=str, default=None,
        help="""Path to an .npz JAX parameter file with which to initialize the model"""
    )
    parser.add_argument(
        "--log_performance", type=bool_type, default=False,
        help="Measure performance"
    )
    parser.add_argument(
        "--wandb", action="store_true", default=False,
        help="Whether to log metrics to Weights & Biases"
    )
    parser.add_argument(
        "--experiment_name", type=str, default=None,
        help="Name of the current experiment. Used for wandb logging"
    )
    parser.add_argument(
        "--wandb_id", type=str, default=None,
        help="ID of a previous run to be resumed"
    )
    parser.add_argument(
        "--wandb_project", type=str, default=None,
        help="Name of the wandb project to which this run will belong"
    )
    parser.add_argument(
        "--wandb_entity", type=str, default=None,
        help="wandb username or team name to which runs are attributed"
    )
    parser.add_argument(
        "--script_modules", type=bool_type, default=False,
        help="Whether to TorchScript eligible components of the model"
    )
    parser.add_argument(
        "--train_chain_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--distillation_chain_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--train_epoch_len", type=int, default=10000,
        help=(
            "The virtual length of each training epoch. Stochastic filtering "
            "of training data means that training datasets have no "
            "well-defined length. This virtual length affects frequency of "
            "validation & checkpointing (by default, one of each per epoch)."
        )
    )
    parser.add_argument(
        "--log_lr", action="store_true", default=False,
        help="Whether to log the actual learning rate"
    )
    parser.add_argument(
        "--config_preset", type=str, default="initial_training",
        help=(
            'Config setting. Choose e.g. "initial_training", "finetuning", '
            '"model_1", etc. By default, the actual values in the config are '
            'used.'
        )
    )
    parser.add_argument(
        "--_distillation_structure_index_path", type=str, default=None,
    )
    parser.add_argument(
        "--alignment_index_path", type=str, default=None,
        help="Training alignment index. See the README for instructions."
    )
    parser.add_argument(
        "--distillation_alignment_index_path", type=str, default=None,
        help="Distillation alignment index. See the README for instructions."
    )

    # Trainer-related options are declared explicitly; main() forwards the
    # ones that match pl.Trainer keyword arguments.
    parser.add_argument(
        "--num_nodes", type=int, default=1,
    )
    parser.add_argument(
        "--gpus", type=int, default=1,
    )
    parser.add_argument(
        "--precision", type=str, default=None,
    )
    parser.add_argument(
        "--replace_sampler_ddp", type=bool_type, default=True,
    )
    parser.add_argument(
        "--max_epochs", type=int, default=1,
    )
    parser.add_argument(
        # Defaults to 0, which disables the initial validation pass.
        "--log_every_n_steps", type=int, default=25,
    )
    parser.add_argument(
        "--num_sanity_val_steps", type=int, default=0,
    )

    args = parser.parse_args()

    if (args.seed is None and
        ((args.gpus is not None and args.gpus > 1) or
         (args.num_nodes is not None and args.num_nodes > 1))):
        raise ValueError("For distributed training, --seed must be specified")

    if (str(args.precision) == "16" and args.deepspeed_config_path is not None):
        raise ValueError("DeepSpeed and FP16 training are not compatible")

    if (args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
        raise ValueError(
            "Choose between loading pretrained Jax-weights and a checkpoint-path")

    # This re-applies the training-time filters at the beginning of every epoch
    args.reload_dataloaders_every_n_epochs = 1

    main(args)