import argparse
import logging
import os
import sys

import pytorch_lightning as pl
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin

import torch

from openfold.config import model_config
from openfold.data.data_modules import OpenFoldDataModule, OpenFoldMultimerDataModule
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants
from openfold.utils.argparse_utils import remove_arguments
from openfold.utils.callbacks import (
    EarlyStoppingVerbose,
)
from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.loss import AlphaFoldLoss, lddt_ca
from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
from openfold.utils.multi_chain_permutation import multi_chain_permutation_align
from openfold.utils.seed import seed_everything
from openfold.utils.superimposition import superimpose
from openfold.utils.tensor_utils import tensor_tree_map
from openfold.utils.validation_metrics import (
    drmsd,
    gdt_ts,
    gdt_ha,
)
from openfold.utils.import_weights import (
    import_jax_weights_,
)
from scripts.zero_to_fp32 import (
    get_fp32_state_dict_from_zero_checkpoint,
    get_global_step_from_zero_checkpoint
)

from openfold.utils.logger import PerformanceLoggingCallback
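# Illustrative invocation (a sketch, not a canonical recipe: the paths are
# placeholders, and the trailing flags come from the argparse setup below
# plus pytorch_lightning's Trainer arguments):
#
#   python3 train_openfold.py mmcif_dir/ alignment_dir/ template_mmcif_dir/ \
#       output_dir/ 2021-10-10 --config_preset initial_training \
#       --seed 42 --gpus 4 --precision 16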


class OpenFoldWrapper(pl.LightningModule):
    def __init__(self, config):
        super(OpenFoldWrapper, self).__init__()
        self.config = config
        self.model = AlphaFold(config)
        self.is_multimer = self.config.globals.is_multimer

        self.loss = AlphaFoldLoss(config.loss)
        self.ema = ExponentialMovingAverage(
            model=self.model, decay=config.ema.decay
        )
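        # The EMA keeps a decayed copy of the model weights, updated after
        # every optimizer step (see on_before_zero_grad) with the usual rule
        #   shadow = decay * shadow + (1 - decay) * param,
        # and swapped in for validation (see validation_step).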
        self.cached_weights = None
        self.last_lr_step = -1

    def forward(self, batch):
        return self.model(batch)

    def _log(self, loss_breakdown, batch, outputs, train=True):
        phase = "train" if train else "val"
        for loss_name, indiv_loss in loss_breakdown.items():
            self.log(
                f"{phase}/{loss_name}", 
                indiv_loss, 
                on_step=train, on_epoch=(not train), logger=True,
            )

            if(train):
                self.log(
                    f"{phase}/{loss_name}_epoch",
                    indiv_loss,
                    on_step=False, on_epoch=True, logger=True,
                )

        with torch.no_grad():
            other_metrics = self._compute_validation_metrics(
                batch, 
                outputs,
                superimposition_metrics=(not train)
            )

        for k,v in other_metrics.items():
            self.log(
                f"{phase}/{k}",
                torch.mean(v),
                on_step=False, on_epoch=True, logger=True
            )

    def training_step(self, batch, batch_idx):
        if(self.ema.device != batch["aatype"].device):
            self.ema.to(batch["aatype"].device)

        ground_truth = batch.pop('gt_features', None)

        # Run the model
        outputs = self(batch)

        # Remove the recycling dimension (inputs are stacked along a trailing
        # axis, one slice per recycling iteration; keep only the final one)
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

        if self.is_multimer:
            batch = multi_chain_permutation_align(out=outputs,
                                                  features=batch,
                                                  ground_truth=ground_truth)

        # Compute loss
        loss, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
        )

        # Log it
        self._log(loss_breakdown, batch, outputs)

        return loss

    def on_before_zero_grad(self, *args, **kwargs):
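        # Lightning calls this hook after optimizer.step() and before
        # zero_grad(), so the EMA shadow tracks the freshly updated weights.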
        self.ema.update(self.model)

    def validation_step(self, batch, batch_idx):
        # At the start of validation, load the EMA weights
        if(self.cached_weights is None):
            # model.state_dict() contains references to model weights rather
            # than copies. Therefore, we need to clone them before calling 
            # load_state_dict().
            clone_param = lambda t: t.detach().clone()
            self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
            self.model.load_state_dict(self.ema.state_dict()["params"])

        ground_truth = batch.pop('gt_features', None)

        # Run the model
        outputs = self(batch)
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

        # Validation always scores with unclamped FAPE
        batch["use_clamped_fape"] = 0.

        if self.is_multimer:
            batch = multi_chain_permutation_align(out=outputs,
                                                  features=batch,
                                                  ground_truth=ground_truth)

        # Compute loss and other metrics
        _, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
        )

        self._log(loss_breakdown, batch, outputs, train=False)
        
    def validation_epoch_end(self, _):
        # Restore the model weights to normal
        self.model.load_state_dict(self.cached_weights)
        self.cached_weights = None

    def _compute_validation_metrics(self, 
        batch, 
        outputs, 
        superimposition_metrics=False
    ):
        metrics = {}
        
        gt_coords = batch["all_atom_positions"]
        pred_coords = outputs["final_atom_positions"]
        all_atom_mask = batch["all_atom_mask"]
    
        # This is super janky for superimposition. Fix later
        gt_coords_masked = gt_coords * all_atom_mask[..., None]
        pred_coords_masked = pred_coords * all_atom_mask[..., None]
        ca_pos = residue_constants.atom_order["CA"]
        gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :]
        pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :]
        all_atom_mask_ca = all_atom_mask[..., ca_pos]
    
        lddt_ca_score = lddt_ca(
            pred_coords,
            gt_coords,
            all_atom_mask,
            eps=self.config.globals.eps,
            per_residue=False,
        )
   
        metrics["lddt_ca"] = lddt_ca_score
   
        drmsd_ca_score = drmsd(
            pred_coords_masked_ca,
            gt_coords_masked_ca,
            mask=all_atom_mask_ca, # still required here to compute n
        )
   
        metrics["drmsd_ca"] = drmsd_ca_score
    
        if(superimposition_metrics):
            superimposed_pred, alignment_rmsd = superimpose(
                gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
            )
            gdt_ts_score = gdt_ts(
                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
            )
            gdt_ha_score = gdt_ha(
                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
            )

            metrics["alignment_rmsd"] = alignment_rmsd
            metrics["gdt_ts"] = gdt_ts_score
            metrics["gdt_ha"] = gdt_ha_score
    
        return metrics

    def configure_optimizers(self, 
        learning_rate: float = 1e-3,
        eps: float = 1e-5,
    ):
        # Ignored as long as a DeepSpeed optimizer is configured
        optimizer = torch.optim.Adam(
            self.model.parameters(), 
            lr=learning_rate, 
            eps=eps
        )

        if self.last_lr_step != -1:
            for group in optimizer.param_groups:
                if 'initial_lr' not in group:
                    group['initial_lr'] = learning_rate

        lr_scheduler = AlphaFoldLRScheduler(
            optimizer,
            last_epoch=self.last_lr_step
        )
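        # AlphaFoldLRScheduler implements an AlphaFold-style schedule
        # (linear warm-up to a peak rate, then stepwise decay); the exact
        # defaults live in openfold/utils/lr_schedulers.py.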

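        # "interval": "step" makes Lightning step the scheduler once per
        # optimizer step rather than once per epoch.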
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
                "name": "AlphaFoldLRScheduler",
            }
        }

    def on_load_checkpoint(self, checkpoint):
        ema = checkpoint["ema"]
        if(not self.model.template_config.enabled):
            # Checkpoints trained with templates carry EMA entries for the
            # template stack; drop them when templates are disabled
            ema["params"] = {k:v for k,v in ema["params"].items() if "template" not in k}
        self.ema.load_state_dict(ema)

    def on_save_checkpoint(self, checkpoint):
        checkpoint["ema"] = self.ema.state_dict()

    def resume_last_lr_step(self, lr_step):
        self.last_lr_step = lr_step

    def load_from_jax(self, jax_path):
        model_basename = os.path.splitext(
                os.path.basename(
                    os.path.normpath(jax_path)
                )
        )[0]
        model_version = "_".join(model_basename.split("_")[1:])
        import_jax_weights_(
                self.model, jax_path, version=model_version
        )


def main(args):
    if(args.seed is not None):
        seed_everything(args.seed)

    config = model_config(
        args.config_preset, 
        train=True, 
        low_prec=(str(args.precision) == "16")
    )
    model_module = OpenFoldWrapper(config)

    if(args.resume_from_ckpt):
        if(os.path.isdir(args.resume_from_ckpt)):  
            last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt)
        else:
            sd = torch.load(args.resume_from_ckpt)
            last_global_step = int(sd['global_step'])
        model_module.resume_last_lr_step(last_global_step)
        logging.info("Successfully loaded last lr step...")
    if(args.resume_from_ckpt and args.resume_model_weights_only):
        if(os.path.isdir(args.resume_from_ckpt)):
            sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt)
        else:
            sd = torch.load(args.resume_from_ckpt)
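        # Checkpoints store parameters under a "module." prefix; strip it so
        # the keys match the bare LightningModule's state dict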
        sd = {k[len("module."):]:v for k,v in sd.items()}
        model_module.load_state_dict(sd)
        logging.info("Successfully loaded model weights...")
    if(args.resume_from_jax_params):
        model_module.load_from_jax(args.resume_from_jax_params)
        logging.info(f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")

    # TorchScript components of the model
    if(args.script_modules):
        script_preset_(model_module)

    if "multimer" in args.config_preset:
        data_module = OpenFoldMultimerDataModule(
            config=config.data, 
            batch_seed=args.seed,
            **vars(args)
        )
    else:
        data_module = OpenFoldDataModule(
            config=config.data, 
            batch_seed=args.seed,
            **vars(args)
        )

    data_module.prepare_data()
    data_module.setup()
    
    callbacks = []
    if(args.checkpoint_every_epoch):
        mc = ModelCheckpoint(
            every_n_epochs=1,
            auto_insert_metric_name=False,
            save_top_k=-1,
        )
        callbacks.append(mc)

    if(args.early_stopping):
        es = EarlyStoppingVerbose(
            monitor="val/lddt_ca",
            min_delta=args.min_delta,
            patience=args.patience,
            verbose=False,
            mode="max",
            check_finite=True,
            strict=True,
        )
        callbacks.append(es)

    if(args.log_performance):
        global_batch_size = args.num_nodes * args.gpus
        perf = PerformanceLoggingCallback(
            log_file=os.path.join(args.output_dir, "performance_log.json"),
            global_batch_size=global_batch_size,
        )
        callbacks.append(perf)

    if(args.log_lr):
        lr_monitor = LearningRateMonitor(logging_interval="step")
        callbacks.append(lr_monitor)

    loggers = []
    if(args.wandb):
        wdb_logger = WandbLogger(
            name=args.experiment_name,
            save_dir=args.output_dir,
            id=args.wandb_id,
            project=args.wandb_project,
            **{"entity": args.wandb_entity}
        )
        loggers.append(wdb_logger)

    if(args.deepspeed_config_path is not None):
        strategy = DeepSpeedPlugin(
            config=args.deepspeed_config_path,
        )
        if(args.wandb):
            wdb_logger.experiment.save(args.deepspeed_config_path)
            wdb_logger.experiment.save("openfold/config.py")
    elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1:
        strategy = DDPPlugin(find_unused_parameters=False)
    else:
        strategy = None
 
    if(args.wandb):
        # Record the installed package versions alongside the run so the
        # environment can be reproduced later
        freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt"
        os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
        wdb_logger.experiment.save(f"{freeze_path}")

    trainer = pl.Trainer.from_argparse_args(
        args,
        default_root_dir=args.output_dir,
        strategy=strategy,
        callbacks=callbacks,
        logger=loggers,
    )

    if(args.resume_model_weights_only):
        ckpt_path = None
    else:
        ckpt_path = args.resume_from_ckpt

    trainer.fit(
        model_module, 
        datamodule=data_module,
        ckpt_path=ckpt_path,
    )


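# bool_type lets boolean flags take explicit values on the command line
# (e.g. "--early_stopping yes" or "--use_small_bfd 0") rather than acting
# as bare store_true switches.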
def bool_type(bool_str: str):
    bool_str_lower = bool_str.lower()
    if bool_str_lower in ('false', 'f', 'no', 'n', '0'):
        return False
    elif bool_str_lower in ('true', 't', 'yes', 'y', '1'):
        return True
    else:
        raise ValueError(f'Cannot interpret {bool_str} as bool')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "train_data_dir", type=str,
        help="Directory containing training mmCIF files"
    )
    parser.add_argument(
        "train_alignment_dir", type=str,
        help="Directory containing precomputed training alignments"
    )
    parser.add_argument(
        "template_mmcif_dir", type=str,
        help="Directory containing mmCIF files to search for templates"
    )
    parser.add_argument(
        "output_dir", type=str,
        help='''Directory in which to output checkpoints, logs, etc. Ignored
                if not on rank 0'''
    )
    parser.add_argument(
        "max_template_date", type=str,
        help='''Cutoff for all templates. In training mode, templates are also 
                filtered by the release date of the target'''
    )
    parser.add_argument(
        "--train_mmcif_data_cache_path", type=str, default=None,
        help="""Path to a JSON file recording metadata for the mmCIF 
                structures used during training"""
    )
    parser.add_argument(
        "--distillation_data_dir", type=str, default=None,
        help="Directory containing training PDB files"
    )
    parser.add_argument(
        "--distillation_alignment_dir", type=str, default=None,
        help="Directory containing precomputed distillation alignments"
    )
    parser.add_argument(
        "--val_data_dir", type=str, default=None,
        help="Directory containing validation mmCIF files"
    )
    parser.add_argument(
        "--val_alignment_dir", type=str, default=None,
        help="Directory containing precomputed validation alignments"
    )
    parser.add_argument(
        "--val_mmcif_data_cache_path", type=str, default=None,
        help="""Path to a JSON file recording metadata for the mmCIF 
                structures used during validation"""
    )
    parser.add_argument(
        "--kalign_binary_path", type=str, default='/usr/bin/kalign',
        help="Path to the kalign binary"
    )
    parser.add_argument(
        "--train_filter_path", type=str, default=None,
        help='''Optional path to a text file containing names of training
                examples to include, one per line. Used to filter the training 
                set'''
    )
    parser.add_argument(
        "--distillation_filter_path", type=str, default=None,
        help="""See --train_filter_path"""
    )
    parser.add_argument(
        "--obsolete_pdbs_file_path", type=str, default=None,
        help="""Path to obsolete.dat file containing list of obsolete PDBs and 
             their replacements."""
    )
    parser.add_argument(
        "--template_release_dates_cache_path", type=str, default=None,
        help="""Output of scripts/generate_mmcif_cache.py run on template mmCIF
                files."""
    )
    parser.add_argument(
        "--use_small_bfd", type=bool_type, default=False,
        help="Whether to use a reduced version of the BFD database"
    )
    parser.add_argument(
        "--seed", type=int, default=None,
        help="Random seed"
    )
    parser.add_argument(
        "--deepspeed_config_path", type=str, default=None,
        help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
    )
    parser.add_argument(
        "--checkpoint_every_epoch", action="store_true", default=False,
        help="""Whether to checkpoint at the end of every training epoch"""
    )
    parser.add_argument(
        "--early_stopping", type=bool_type, default=False,
        help="""Whether to stop training when the monitored validation metric
                (val/lddt_ca) stops improving"""
    )
    parser.add_argument(
        "--min_delta", type=float, default=0,
        help="""The smallest change in the monitored metric that counts as an 
                improvement for the purposes of early stopping"""
    )
    parser.add_argument(
        "--patience", type=int, default=3,
        help="Early stopping patience"
    )
    parser.add_argument(
        "--resume_from_ckpt", type=str, default=None,
        help="Path to a model checkpoint from which to restore training state"
    )
    parser.add_argument(
        "--resume_model_weights_only", type=bool_type, default=False,
        help="Whether to load just model weights as opposed to training state"
    )
    parser.add_argument(
        "--resume_from_jax_params", type=str, default=None,
        help="""Path to an .npz JAX parameter file with which to initialize the model"""
    )
    parser.add_argument(
        "--log_performance", type=bool_type, default=False,
        help="Whether to measure and log training performance"
    )
    parser.add_argument(
        "--wandb", action="store_true", default=False,
        help="Whether to log metrics to Weights & Biases"
    )
    parser.add_argument(
        "--experiment_name", type=str, default=None,
        help="Name of the current experiment. Used for wandb logging"
    )
    parser.add_argument(
        "--wandb_id", type=str, default=None,
        help="ID of a previous run to be resumed"
    )
    parser.add_argument(
        "--wandb_project", type=str, default=None,
        help="Name of the wandb project to which this run will belong"
    )
    parser.add_argument(
        "--wandb_entity", type=str, default=None,
        help="wandb username or team name to which runs are attributed"
    )
    parser.add_argument(
        "--script_modules", type=bool_type, default=False,
        help="Whether to TorchScript-compile eligible components of the model"
    )
    parser.add_argument(
        "--train_chain_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--distillation_chain_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--train_epoch_len", type=int, default=10000,
        help=(
            "The virtual length of each training epoch. Stochastic filtering "
            "of training data means that training datasets have no "
            "well-defined length. This virtual length affects frequency of "
            "validation & checkpointing (by default, one of each per epoch)."
        )
    )
    parser.add_argument(
        "--log_lr", action="store_true", default=False,
        help="Whether to log the actual learning rate"
    )
    parser.add_argument(
        "--config_preset", type=str, default="initial_training",
        help=(
            'Config setting. Choose e.g. "initial_training", "finetuning", '
            '"model_1", etc. By default, the actual values in the config are '
            'used.'
        )
    )
    parser.add_argument(
        "--_distillation_structure_index_path", type=str, default=None,
    )
    parser.add_argument(
        "--alignment_index_path", type=str, default=None,
        help="Training alignment index. See the README for instructions."
    )
    parser.add_argument(
        "--distillation_alignment_index_path", type=str, default=None,
        help="Distillation alignment index. See the README for instructions."
    )
    parser = pl.Trainer.add_argparse_args(parser)
   
    # Disable the initial validation pass
    parser.set_defaults(
        num_sanity_val_steps=0,
    )

    # Remove some buggy/redundant arguments introduced by the Trainer
    remove_arguments(
        parser, 
        [
            "--accelerator", 
            "--resume_from_checkpoint",
            "--reload_dataloaders_every_epoch",
            "--reload_dataloaders_every_n_epochs",
        ]
    ) 

    args = parser.parse_args()

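    # A shared seed keeps each rank's stochastic data pipeline in sync
    # across processes.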
    if(args.seed is None and 
        ((args.gpus is not None and args.gpus > 1) or 
         (args.num_nodes is not None and args.num_nodes > 1))):
        raise ValueError("For distributed training, --seed must be specified")

    if(str(args.precision) == "16" and args.deepspeed_config_path is not None):
        raise ValueError("DeepSpeed and FP16 training are not compatible")

    if(args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
        raise ValueError(
            "--resume_from_jax_params and --resume_from_ckpt are mutually exclusive"
        )

    # This re-applies the training-time filters at the beginning of every epoch
    args.reload_dataloaders_every_n_epochs = 1

    main(args)