# train_openfold.py
import argparse
import logging
import os
import random
import subprocess
import sys
import time

import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
from pytorch_lightning.plugins.environments import SLURMEnvironment
import torch

from openfold.config import model_config
from openfold.data.data_modules import (
    OpenFoldDataModule,
    DummyDataLoader,
)
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants
from openfold.utils.argparse import remove_arguments
from openfold.utils.callbacks import (
    EarlyStoppingVerbose,
)
from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.loss import AlphaFoldLoss, lddt_ca
from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
from openfold.utils.seed import seed_everything
from openfold.utils.superimposition import superimpose
from openfold.utils.tensor_utils import tensor_tree_map
from openfold.utils.validation_metrics import (
    drmsd,
    gdt_ts,
    gdt_ha,
)
from openfold.utils.import_weights import (
    import_jax_weights_,
)
from scripts.zero_to_fp32 import (
    get_fp32_state_dict_from_zero_checkpoint,
    get_global_step_from_zero_checkpoint
)

from openfold.utils.logger import PerformanceLoggingCallback


class OpenFoldWrapper(pl.LightningModule):
    """LightningModule bundling the AlphaFold model, its loss and an EMA copy.

    Training steps run the raw model weights; during validation the model is
    temporarily switched to the exponential-moving-average weights, which are
    swapped back out in ``validation_epoch_end``.
    """

    def __init__(self, config):
        """Build the model, loss and EMA tracker from ``config``.

        Args:
            config: Model configuration; ``config.loss``, ``config.ema.decay``
                and ``config.globals.eps`` are read by this class.
        """
        super().__init__()
        self.config = config
        self.model = AlphaFold(config)
        self.loss = AlphaFoldLoss(config.loss)
        self.ema = ExponentialMovingAverage(
            model=self.model, decay=config.ema.decay
        )

        # Copy of the training weights held while EMA weights are loaded for
        # validation; None whenever the model holds its training weights.
        self.cached_weights = None
        # Step from which the LR schedule should resume; -1 means "fresh run".
        self.last_lr_step = -1

    def forward(self, batch):
        """Run the wrapped model on ``batch`` and return its outputs."""
        return self.model(batch)

    def _log(self, loss_breakdown, batch, outputs, train=True):
        """Log every loss component plus structural metrics.

        Args:
            loss_breakdown: Mapping of loss name to value, as returned by the
                loss when called with ``_return_breakdown=True``.
            batch: Feature dict with the recycling dimension already removed.
            outputs: Model outputs for ``batch``.
            train: Selects the ``train``/``val`` prefix. Training losses are
                logged per step and again as an ``_epoch`` aggregate;
                validation losses are logged per epoch only.
        """
        phase = "train" if train else "val"
        for loss_name, indiv_loss in loss_breakdown.items():
            self.log(
                f"{phase}/{loss_name}",
                indiv_loss,
                on_step=train, on_epoch=(not train), logger=True,
            )

            if train:
                self.log(
                    f"{phase}/{loss_name}_epoch",
                    indiv_loss,
                    on_step=False, on_epoch=True, logger=True,
                )

        # Metrics are for monitoring only; keep them out of the autograd graph.
        with torch.no_grad():
            other_metrics = self._compute_validation_metrics(
                batch,
                outputs,
                superimposition_metrics=(not train),
            )

        for k, v in other_metrics.items():
            self.log(
                f"{phase}/{k}",
                torch.mean(v),
                on_step=False, on_epoch=True, logger=True,
            )

    def training_step(self, batch, batch_idx):
        """Run one training step and return the scalar loss."""
        # Keep the EMA on the same device as the incoming batch.
        if self.ema.device != batch["aatype"].device:
            self.ema.to(batch["aatype"].device)

        # Run the model
        outputs = self(batch)

        # Remove the recycling dimension
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

        # Compute loss
        loss, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
        )

        # Log it
        self._log(loss_breakdown, batch, outputs)

        return loss

    def on_before_zero_grad(self, *args, **kwargs):
        """Fold the current model weights into the EMA."""
        self.ema.update(self.model)

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch using the EMA weights."""
        # At the start of validation, load the EMA weights
        if self.cached_weights is None:
            # model.state_dict() contains references to model weights rather
            # than copies. Therefore, we need to clone them before calling
            # load_state_dict().
            clone_param = lambda t: t.detach().clone()
            self.cached_weights = tensor_tree_map(
                clone_param, self.model.state_dict()
            )
            self.model.load_state_dict(self.ema.state_dict()["params"])

        # Run the model
        outputs = self(batch)
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

        # Compute loss and other metrics
        batch["use_clamped_fape"] = 0.
        _, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
        )

        self._log(loss_breakdown, batch, outputs, train=False)

    def validation_epoch_end(self, _):
        """Restore the training weights cached by ``validation_step``."""
        # Nothing was swapped in if no validation step ran this epoch.
        if self.cached_weights is None:
            return
        self.model.load_state_dict(self.cached_weights)
        self.cached_weights = None

    def _compute_validation_metrics(self,
        batch,
        outputs,
        superimposition_metrics=False
    ):
        """Compute structural quality metrics for a batch.

        Args:
            batch: Feature dict providing ``all_atom_positions`` and
                ``all_atom_mask``.
            outputs: Model outputs providing ``final_atom_positions``.
            superimposition_metrics: When True, also superimpose the predicted
                C-alpha trace onto the ground truth and report
                ``alignment_rmsd``, ``gdt_ts`` and ``gdt_ha``.

        Returns:
            Dict of metric name to tensor. Always contains ``lddt_ca`` and
            ``drmsd_ca``.
        """
        metrics = {}

        gt_coords = batch["all_atom_positions"]
        pred_coords = outputs["final_atom_positions"]
        all_atom_mask = batch["all_atom_mask"]

        # Masked copies are what the superimposition-based metrics consume.
        # This is janky; fix later.
        gt_coords_masked = gt_coords * all_atom_mask[..., None]
        pred_coords_masked = pred_coords * all_atom_mask[..., None]
        ca_pos = residue_constants.atom_order["CA"]
        gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :]
        pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :]
        all_atom_mask_ca = all_atom_mask[..., ca_pos]

        metrics["lddt_ca"] = lddt_ca(
            pred_coords,
            gt_coords,
            all_atom_mask,
            eps=self.config.globals.eps,
            per_residue=False,
        )

        metrics["drmsd_ca"] = drmsd(
            pred_coords_masked_ca,
            gt_coords_masked_ca,
            mask=all_atom_mask_ca,  # still required here to compute n
        )

        if superimposition_metrics:
            superimposed_pred, alignment_rmsd = superimpose(
                gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
            )
            metrics["alignment_rmsd"] = alignment_rmsd
            metrics["gdt_ts"] = gdt_ts(
                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
            )
            metrics["gdt_ha"] = gdt_ha(
                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
            )

        return metrics

    def configure_optimizers(self,
        learning_rate: float = 1e-3,
        eps: float = 1e-5,
    ) -> dict:
        """Create the Adam optimizer and the per-step LR scheduler.

        Args:
            learning_rate: Learning rate handed to Adam, and used as
                ``initial_lr`` when resuming the schedule.
            eps: Adam epsilon.

        Returns:
            Dict with the ``optimizer`` and an ``lr_scheduler`` config whose
            scheduler is stepped every training step.
        """
        # Ignored as long as a DeepSpeed optimizer is configured
        optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=learning_rate,
            eps=eps,
        )

        # torch LR schedulers require every param group to carry
        # 'initial_lr' when constructed with last_epoch != -1.
        if self.last_lr_step != -1:
            for group in optimizer.param_groups:
                if 'initial_lr' not in group:
                    group['initial_lr'] = learning_rate

        # Resume the schedule from the recorded step; -1 is a fresh schedule.
        # NOTE(review): assumes AlphaFoldLRScheduler accepts the standard
        # torch `last_epoch` keyword - confirm against lr_schedulers.py.
        lr_scheduler = AlphaFoldLRScheduler(
            optimizer,
            last_epoch=self.last_lr_step,
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
                "name": "AlphaFoldLRScheduler",
            }
        }

    def on_load_checkpoint(self, checkpoint):
        """Restore the EMA state stored under ``checkpoint["ema"]``."""
        ema = checkpoint["ema"]
        if not self.model.template_config.enabled:
            # Drop template parameters when templates are disabled.
            ema["params"] = {
                k: v for k, v in ema["params"].items() if "template" not in k
            }
        self.ema.load_state_dict(ema)

    def on_save_checkpoint(self, checkpoint):
        """Store the EMA state under ``checkpoint["ema"]``."""
        checkpoint["ema"] = self.ema.state_dict()

    def resume_last_lr_step(self, lr_step):
        """Record the step from which the LR schedule should resume."""
        self.last_lr_step = lr_step

    def load_from_jax(self, jax_path):
        """Initialize the model from a JAX parameter file.

        The model version is taken from the file's base name: everything after
        the first underscore, with the extension removed.
        """
        model_basename = os.path.splitext(
            os.path.basename(os.path.normpath(jax_path))
        )[0]
        model_version = "_".join(model_basename.split("_")[1:])
        import_jax_weights_(
            self.model, jax_path, version=model_version
        )


def main(args):
    """Build the model, data module, callbacks and trainer, then fit.

    Args:
        args: Parsed command-line namespace holding this script's own
            arguments plus those added by ``pl.Trainer.add_argparse_args``.
    """
    if args.seed is not None:
        seed_everything(args.seed)

    config = model_config(
        args.config_preset,
        train=True,
        low_prec=(str(args.precision) == "16"),
    )

    model_module = OpenFoldWrapper(config)
    if args.resume_from_ckpt:
        # A directory is treated as a DeepSpeed ZeRO checkpoint, a file as a
        # regular checkpoint readable by torch.load.
        if os.path.isdir(args.resume_from_ckpt):
            last_global_step = get_global_step_from_zero_checkpoint(
                args.resume_from_ckpt
            )
        else:
            # Only metadata is needed here, so keep the tensors on the CPU.
            sd = torch.load(args.resume_from_ckpt, map_location="cpu")
            last_global_step = int(sd['global_step'])
        model_module.resume_last_lr_step(last_global_step)
        logging.info("Successfully loaded last lr step...")
    if args.resume_from_ckpt and args.resume_model_weights_only:
        if os.path.isdir(args.resume_from_ckpt):
            sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt)
        else:
            sd = torch.load(args.resume_from_ckpt, map_location="cpu")
        # Strip the "module." prefix only from keys that actually carry it,
        # instead of blindly chopping the first characters off every key.
        prefix = "module."
        sd = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in sd.items()
        }
        model_module.load_state_dict(sd)
        logging.info("Successfully loaded model weights...")
    if args.resume_from_jax_params:
        model_module.load_from_jax(args.resume_from_jax_params)
        logging.info(f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")

    # TorchScript components of the model
    if args.script_modules:
        script_preset_(model_module)

    #data_module = DummyDataLoader("new_batch.pickle")
    data_module = OpenFoldDataModule(
        config=config.data,
        batch_seed=args.seed,
        **vars(args)
    )

    data_module.prepare_data()
    data_module.setup()

    callbacks = []
    if args.checkpoint_every_epoch:
        mc = ModelCheckpoint(
            every_n_epochs=1,
            auto_insert_metric_name=False,
            save_top_k=-1,
        )
        callbacks.append(mc)

    if args.early_stopping:
        es = EarlyStoppingVerbose(
            monitor="val/lddt_ca",
            min_delta=args.min_delta,
            patience=args.patience,
            verbose=False,
            mode="max",
            check_finite=True,
            strict=True,
        )
        callbacks.append(es)

    if args.log_performance:
        global_batch_size = args.num_nodes * args.gpus
        perf = PerformanceLoggingCallback(
            log_file=os.path.join(args.output_dir, "performance_log.json"),
            global_batch_size=global_batch_size,
        )
        callbacks.append(perf)

    if args.log_lr:
        lr_monitor = LearningRateMonitor(logging_interval="step")
        callbacks.append(lr_monitor)

    loggers = []
    if args.wandb:
        wdb_logger = WandbLogger(
            name=args.experiment_name,
            save_dir=args.output_dir,
            id=args.wandb_id,
            project=args.wandb_project,
            **{"entity": args.wandb_entity}
        )
        loggers.append(wdb_logger)

    if args.deepspeed_config_path is not None:
        strategy = DeepSpeedPlugin(
            config=args.deepspeed_config_path,
        )
        if args.wandb:
            wdb_logger.experiment.save(args.deepspeed_config_path)
            wdb_logger.experiment.save("openfold/config.py")
    elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1:
        strategy = DDPPlugin(find_unused_parameters=False)
    else:
        strategy = None

    if args.wandb:
        # Snapshot installed package versions next to the run. An argument
        # list plus a file handle avoids shell quoting problems with paths
        # containing spaces; a failing pip is tolerated (best effort).
        freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt"
        with open(freeze_path, "w") as freeze_file:
            subprocess.run(
                [sys.executable, "-m", "pip", "freeze"],
                stdout=freeze_file,
                check=False,
            )
        wdb_logger.experiment.save(freeze_path)

    trainer = pl.Trainer.from_argparse_args(
        args,
        default_root_dir=args.output_dir,
        strategy=strategy,
        callbacks=callbacks,
        logger=loggers,
    )

    # With weights-only resumption the weights were loaded above, so the
    # trainer must not also restore the full training state.
    if args.resume_model_weights_only:
        ckpt_path = None
    else:
        ckpt_path = args.resume_from_ckpt

    trainer.fit(
        model_module,
        datamodule=data_module,
        ckpt_path=ckpt_path,
    )


def bool_type(bool_str: str) -> bool:
    """Parse a command-line string into a bool (for argparse ``type=``).

    Matching is case-insensitive and ignores surrounding whitespace.

    Args:
        bool_str: One of ``false/f/no/n/0`` or ``true/t/yes/y/1``.

    Returns:
        The corresponding boolean.

    Raises:
        ValueError: If ``bool_str`` is not one of the recognized spellings.
    """
    normalized = bool_str.strip().lower()
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    raise ValueError(f'Cannot interpret {bool_str} as bool')


# Script entry point: declare the command-line interface, validate argument
# combinations that cannot work together, then hand over to main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "train_data_dir", type=str,
        help="Directory containing training mmCIF files"
    )
    parser.add_argument(
        "train_alignment_dir", type=str,
        help="Directory containing precomputed training alignments"
    )
    parser.add_argument(
        "template_mmcif_dir", type=str,
        help="Directory containing mmCIF files to search for templates"
    )
    parser.add_argument(
        "output_dir", type=str,
        help='''Directory in which to output checkpoints, logs, etc. Ignored
                if not on rank 0'''
    )
    parser.add_argument(
        "max_template_date", type=str,
        help='''Cutoff for all templates. In training mode, templates are also 
                filtered by the release date of the target'''
    )
    parser.add_argument(
        "--use_single_seq_mode", type=str, default=None,
        help="Use single sequence embeddings instead of MSAs."
    )
    parser.add_argument(
        "--distillation_data_dir", type=str, default=None,
        help="Directory containing training PDB files"
    )
    parser.add_argument(
        "--distillation_alignment_dir", type=str, default=None,
        help="Directory containing precomputed distillation alignments"
    )
    parser.add_argument(
        "--val_data_dir", type=str, default=None,
        help="Directory containing validation mmCIF files"
    )
    parser.add_argument(
        "--val_alignment_dir", type=str, default=None,
        help="Directory containing precomputed validation alignments"
    )
    parser.add_argument(
        "--kalign_binary_path", type=str, default='/usr/bin/kalign',
        help="Path to the kalign binary"
    )
    parser.add_argument(
        "--train_filter_path", type=str, default=None,
        help='''Optional path to a text file containing names of training
                examples to include, one per line. Used to filter the training 
                set'''
    )
    parser.add_argument(
        "--distillation_filter_path", type=str, default=None,
        help="""See --train_filter_path"""
    )
    parser.add_argument(
        "--obsolete_pdbs_file_path", type=str, default=None,
        help="""Path to obsolete.dat file containing list of obsolete PDBs and 
             their replacements."""
    )
    parser.add_argument(
        "--template_release_dates_cache_path", type=str, default=None,
        help="""Output of scripts/generate_mmcif_cache.py run on template mmCIF
                files."""
    )
    parser.add_argument(
        "--use_small_bfd", type=bool_type, default=False,
        help="Whether to use a reduced version of the BFD database"
    )
    parser.add_argument(
        "--seed", type=int, default=None,
        help="Random seed"
    )
    parser.add_argument(
        "--deepspeed_config_path", type=str, default=None,
        help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
    )
    parser.add_argument(
        "--checkpoint_every_epoch", action="store_true", default=False,
        help="""Whether to checkpoint at the end of every training epoch"""
    )
    # Early stopping monitors val/lddt_ca with mode="max" (see main()), so the
    # help text describes an lDDT-Ca improvement rather than a loss decrease.
    parser.add_argument(
        "--early_stopping", type=bool_type, default=False,
        help="Whether to stop training when validation lDDT-Ca stops improving"
    )
    parser.add_argument(
        "--min_delta", type=float, default=0,
        help="""The smallest increase in validation lDDT-Ca that counts as an 
                improvement for the purposes of early stopping"""
    )
    parser.add_argument(
        "--patience", type=int, default=3,
        help="Early stopping patience"
    )
    parser.add_argument(
        "--resume_from_ckpt", type=str, default=None,
        help="Path to a model checkpoint from which to restore training state"
    )
    parser.add_argument(
        "--resume_model_weights_only", type=bool_type, default=False,
        help="Whether to load just model weights as opposed to training state"
    )
    parser.add_argument(
        "--resume_from_jax_params", type=str, default=None,
        help="""Path to an .npz JAX parameter file with which to initialize the model"""
    )
    parser.add_argument(
        "--log_performance", type=bool_type, default=False,
        help="Measure performance"
    )
    parser.add_argument(
        "--wandb", action="store_true", default=False,
        help="Whether to log metrics to Weights & Biases"
    )
    parser.add_argument(
        "--experiment_name", type=str, default=None,
        help="Name of the current experiment. Used for wandb logging"
    )
    parser.add_argument(
        "--wandb_id", type=str, default=None,
        help="ID of a previous run to be resumed"
    )
    parser.add_argument(
        "--wandb_project", type=str, default=None,
        help="Name of the wandb project to which this run will belong"
    )
    parser.add_argument(
        "--wandb_entity", type=str, default=None,
        help="wandb username or team name to which runs are attributed"
    )
    parser.add_argument(
        "--script_modules", type=bool_type, default=False,
        help="Whether to TorchScript eligible components of the model"
    )
    parser.add_argument(
        "--train_chain_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--distillation_chain_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--train_epoch_len", type=int, default=10000,
        help=(
            "The virtual length of each training epoch. Stochastic filtering "
            "of training data means that training datasets have no "
            "well-defined length. This virtual length affects frequency of "
            "validation & checkpointing (by default, one of each per epoch)."
        )
    )
    parser.add_argument(
        "--log_lr", action="store_true", default=False,
        help="Whether to log the actual learning rate"
    )
    parser.add_argument(
        "--config_preset", type=str, default="initial_training",
        help=(
            'Config setting. Choose e.g. "initial_training", "finetuning", '
            '"model_1", etc. By default, the actual values in the config are '
            'used.'
        )
    )
    parser.add_argument(
        "--_distillation_structure_index_path", type=str, default=None,
    )
    parser.add_argument(
        "--alignment_index_path", type=str, default=None,
        help="Training alignment index. See the README for instructions."
    )
    parser.add_argument(
        "--distillation_alignment_index_path", type=str, default=None,
        help="Distillation alignment index. See the README for instructions."
    )
    parser = pl.Trainer.add_argparse_args(parser)

    # Disable the initial validation pass
    parser.set_defaults(
        num_sanity_val_steps=0,
    )

    # Remove some buggy/redundant arguments introduced by the Trainer
    remove_arguments(
        parser,
        [
            "--accelerator",
            "--resume_from_checkpoint",
            "--reload_dataloaders_every_epoch",
            "--reload_dataloaders_every_n_epochs",
        ]
    )

    args = parser.parse_args()

    if(args.seed is None and
        ((args.gpus is not None and args.gpus > 1) or
         (args.num_nodes is not None and args.num_nodes > 1))):
        raise ValueError("For distributed training, --seed must be specified")

    if(str(args.precision) == "16" and args.deepspeed_config_path is not None):
        raise ValueError("DeepSpeed and FP16 training are not compatible")

    if(args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
        raise ValueError("Choose between loading pretrained Jax-weights and a checkpoint-path")

    # This re-applies the training-time filters at the beginning of every epoch
    args.reload_dataloaders_every_n_epochs = 1

    main(args)