import argparse
import logging
import os

# Optional manual overrides for device / distributed setup, kept for reference.
#os.environ["CUDA_VISIBLE_DEVICES"] = "0"
#os.environ["MASTER_ADDR"]="10.119.81.14"
#os.environ["MASTER_PORT"]="42069"
#os.environ["NODE_RANK"]="0"

import random
import time

import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
from pytorch_lightning.plugins.environments import SLURMEnvironment
import torch

from openfold.config import model_config
from openfold.data.data_modules import (
    OpenFoldDataModule,
    DummyDataLoader,
)
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.utils.callbacks import (
    EarlyStoppingVerbose,
)
from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.argparse import remove_arguments
from openfold.utils.logger import PerformanceLoggingCallback
from openfold.utils.loss import AlphaFoldLoss, lddt_ca
from openfold.utils.seed import seed_everything
from openfold.utils.tensor_utils import tensor_tree_map

from scripts.zero_to_fp32 import (
    get_fp32_state_dict_from_zero_checkpoint
)

class OpenFoldWrapper(pl.LightningModule):
    """LightningModule wrapping the AlphaFold model for training.

    Owns the model, the AlphaFold loss, and an exponential moving average
    (EMA) of the model weights. During validation, the EMA weights are
    temporarily swapped in; the original weights are restored afterwards.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = AlphaFold(config)
        self.loss = AlphaFoldLoss(config.loss)
        self.ema = ExponentialMovingAverage(
            model=self.model, decay=config.ema.decay
        )
        # Holds the non-EMA weights while validation runs with EMA weights.
        self.cached_weights = None
        # Tracks the last global step at which the LR scheduler was stepped.
        self.last_lr_step = 0

    def forward(self, batch):
        return self.model(batch)

    def training_step(self, batch, batch_idx):
        # Keep the EMA shadow parameters on the same device as the batch.
        if(self.ema.device != batch["aatype"].device):
            self.ema.to(batch["aatype"].device)

        # Run the model
        outputs = self(batch)

        # Remove the recycling dimension
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

        # Compute loss
        loss = self.loss(outputs, batch)

        self.log("train/loss", loss, on_step=True, logger=True)

        return loss

    def training_step_end(self, outputs):
        # Temporary measure to address DeepSpeed scheduler bug
        if(self.trainer.global_step != self.last_lr_step):
            self.lr_schedulers().step()
            self.last_lr_step = self.trainer.global_step

    def validation_step(self, batch, batch_idx):
        # At the start of validation, load the EMA weights
        if(self.cached_weights is None):
            self.cached_weights = self.model.state_dict()
            self.model.load_state_dict(self.ema.state_dict()["params"])

        # Calculate validation loss (lDDT over C-alpha atoms)
        outputs = self(batch)
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

        loss = lddt_ca(
            outputs["final_atom_positions"],
            batch["all_atom_positions"],
            batch["all_atom_mask"],
            eps=self.config.globals.eps,
            per_residue=False,
        )
        self.log("val/loss", loss, logger=True)

    def validation_epoch_end(self, _):
        # Restore the model weights to normal
        self.model.load_state_dict(self.cached_weights)
        self.cached_weights = None

    def configure_optimizers(self,
        learning_rate: float = 1e-3,
        eps: float = 1e-5,
    ) -> torch.optim.Adam:
        # Ignored as long as a DeepSpeed optimizer is configured
        return torch.optim.Adam(
            self.model.parameters(),
            lr=learning_rate,
            eps=eps
        )

    def on_before_zero_grad(self, *args, **kwargs):
        # Fold the freshly-updated weights into the EMA before grads are zeroed.
        self.ema.update(self.model)

    def on_load_checkpoint(self, checkpoint):
        self.ema.load_state_dict(checkpoint["ema"])

    def on_save_checkpoint(self, checkpoint):
        checkpoint["ema"] = self.ema.state_dict()

def main(args):
    """Build the model, data module, callbacks, and trainer, then fit.

    ``args`` is the parsed CLI namespace (including all pytorch_lightning
    Trainer arguments added via ``add_argparse_args``).
    """
    if(args.seed is not None):
        seed_everything(args.seed)

    # NOTE(review): model preset is hard-coded to "model_1"; low precision is
    # enabled only when --precision 16 was requested.
    config = model_config(
        "model_1",
        train=True,
        low_prec=(args.precision == "16")
    )

    model_module = OpenFoldWrapper(config)
    if(args.resume_from_ckpt and args.resume_model_weights_only):
        # Pull fp32 weights out of a DeepSpeed ZeRO checkpoint and strip the
        # "module." prefix DeepSpeed prepends to parameter names.
        sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt)
        sd = {k[len("module."):]:v for k,v in sd.items()}
        model_module.load_state_dict(sd)
        logging.info("Successfully loaded model weights...")

    # TorchScript components of the model
    if(args.script_modules):
        script_preset_(model_module)

    #data_module = DummyDataLoader("new_batch.pickle")
    data_module = OpenFoldDataModule(
        config=config.data,
        batch_seed=args.seed,
        **vars(args)
    )

    data_module.prepare_data()
    data_module.setup()

    callbacks = []
    if(args.checkpoint_every_epoch):
        mc = ModelCheckpoint(
            every_n_epochs=1,
        )
        callbacks.append(mc)

    if(args.early_stopping):
        es = EarlyStoppingVerbose(
            monitor="val/loss",
            min_delta=args.min_delta,
            patience=args.patience,
            verbose=False,
            mode="min",
            check_finite=True,
            strict=True,
        )
        callbacks.append(es)

    if(args.log_performance):
        # assumes one sample per GPU — TODO confirm against data module
        global_batch_size = args.num_nodes * args.gpus
        perf = PerformanceLoggingCallback(
            log_file=os.path.join(args.output_dir, "performance_log.json"),
            global_batch_size=global_batch_size,
        )
        callbacks.append(perf)

    if(args.log_lr):
        lr_monitor = LearningRateMonitor(logging_interval="step")
        callbacks.append(lr_monitor)

    loggers = []
    if(args.wandb):
        wdb_logger = WandbLogger(
            name=args.experiment_name,
            save_dir=args.output_dir,
            id=args.wandb_id,
            project=args.wandb_project,
            **{"entity": args.wandb_entity}
        )
        loggers.append(wdb_logger)

    if(args.deepspeed_config_path is not None):
        # SLURM cluster-environment wiring, kept for reference:
        #if "SLURM_JOB_ID" in os.environ:
        #    cluster_environment = SLURMEnvironment()
        #else:
        #    cluster_environment = None

        strategy = DeepSpeedPlugin(
            config=args.deepspeed_config_path,
        #    cluster_environment=cluster_environment,
        )

        if(args.wandb):
            # Archive the exact DeepSpeed + model configs with the run.
            wdb_logger.experiment.save(args.deepspeed_config_path)
            wdb_logger.experiment.save("openfold/config.py")
    elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1:
        strategy = DDPPlugin(find_unused_parameters=False)
    else:
        strategy = None

    trainer = pl.Trainer.from_argparse_args(
        args,
        default_root_dir=args.output_dir,
        strategy=strategy,
        callbacks=callbacks,
        logger=loggers,
    )

    if(args.resume_model_weights_only):
        ckpt_path = None
    else:
        ckpt_path = args.resume_from_ckpt

    trainer.fit(
        model_module,
        datamodule=data_module,
        ckpt_path=ckpt_path,
    )

def bool_type(bool_str: str):
    """Interpret a string as a boolean for use as an argparse ``type=``.

    Accepts (case-insensitively) true/t/yes/y/1 and false/f/no/n/0;
    raises ValueError for anything else.
    """
    normalized = bool_str.lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise ValueError(f'Cannot interpret {bool_str} as bool')


Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
248
249
250
251
252
253
254
255
256
257
258
259
260
261
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "train_data_dir", type=str,
        help="Directory containing training mmCIF files"
    )
    parser.add_argument(
        "train_alignment_dir", type=str,
        help="Directory containing precomputed training alignments"
    )
    parser.add_argument(
        "template_mmcif_dir", type=str,
        help="Directory containing mmCIF files to search for templates"
    )
    parser.add_argument(
        "output_dir", type=str,
        help='''Directory in which to output checkpoints, logs, etc. Ignored
                if not on rank 0'''
    )
    parser.add_argument(
        "max_template_date", type=str,
        help='''Cutoff for all templates. In training mode, templates are also 
                filtered by the release date of the target'''
    )
    parser.add_argument(
        "--distillation_data_dir", type=str, default=None,
        help="Directory containing training PDB files"
    )
    parser.add_argument(
        "--distillation_alignment_dir", type=str, default=None,
        help="Directory containing precomputed distillation alignments"
    )
    parser.add_argument(
        "--val_data_dir", type=str, default=None,
        help="Directory containing validation mmCIF files"
    )
    parser.add_argument(
        "--val_alignment_dir", type=str, default=None,
        help="Directory containing precomputed validation alignments"
    )
    parser.add_argument(
        "--kalign_binary_path", type=str, default='/usr/bin/kalign',
        help="Path to the kalign binary"
    )
    parser.add_argument(
        "--train_mapping_path", type=str, default=None,
        help='''Optional path to a .json file containing a mapping from
                consecutive numerical indices to sample names. Used to filter
                the training set'''
    )
    parser.add_argument(
        "--distillation_mapping_path", type=str, default=None,
        help="""See --train_mapping_path"""
    )
    parser.add_argument(
        "--obsolete_pdbs_file_path", type=str, default=None,
        help="""Path to obsolete.dat file containing list of obsolete PDBs and 
             their replacements."""
    )
    parser.add_argument(
        "--template_release_dates_cache_path", type=str, default=None,
        help="""Output of scripts/generate_mmcif_cache.py run on template mmCIF
                files."""
    )
    parser.add_argument(
        "--use_small_bfd", type=bool_type, default=False,
        help="Whether to use a reduced version of the BFD database"
    )
    parser.add_argument(
        "--seed", type=int, default=None,
        help="Random seed"
    )
    parser.add_argument(
        "--deepspeed_config_path", type=str, default=None,
        help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
    )
    parser.add_argument(
        "--checkpoint_every_epoch", action="store_true", default=False,
        help="""Whether to checkpoint at the end of every training epoch"""
    )
    parser.add_argument(
        "--early_stopping", type=bool_type, default=False,
        help="Whether to stop training when validation loss fails to decrease"
    )
    parser.add_argument(
        "--min_delta", type=float, default=0,
        help="""The smallest decrease in validation loss that counts as an 
                improvement for the purposes of early stopping"""
    )
    parser.add_argument(
        "--patience", type=int, default=3,
        help="Early stopping patience"
    )
    parser.add_argument(
        "--resume_from_ckpt", type=str, default=None,
        help="Path to a model checkpoint from which to restore training state"
    )
    parser.add_argument(
        "--resume_model_weights_only", type=bool_type, default=False,
        help="Whether to load just model weights as opposed to training state"
    )
    parser.add_argument(
        "--log_performance", type=bool_type, default=False,
        help="Measure performance"
    )
    parser.add_argument(
        "--wandb", action="store_true", default=False,
    )
    parser.add_argument(
        "--experiment_name", type=str, default=None,
    )
    parser.add_argument(
        "--wandb_id", type=str, default=None,
    )
    parser.add_argument(
        "--wandb_project", type=str, default=None,
    )
    parser.add_argument(
        "--wandb_entity", type=str, default=None,
    )
    parser.add_argument(
        # Typo fixed: "of them model" -> "of the model"
        "--script_modules", type=bool_type, default=False,
        help="Whether to TorchScript eligible components of the model"
    )
    parser.add_argument(
        "--train_prot_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--distillation_prot_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--train_epoch_len", type=int, default=10000,
    )
    parser.add_argument(
        "--_alignment_index_path", type=str, default=None,
    )
    parser.add_argument(
        "--log_lr", action="store_true", default=False,
    )
    parser = pl.Trainer.add_argparse_args(parser)

    # Disable the initial validation pass
    parser.set_defaults(
        num_sanity_val_steps=0,
    )

    # Remove some buggy/redundant arguments introduced by the Trainer
    remove_arguments(
        parser,
        [
            "--accelerator",
            "--resume_from_checkpoint",
            "--reload_dataloaders_every_epoch",
            "--reload_dataloaders_every_n_epochs",
        ]
    )

    args = parser.parse_args()

    # A shared seed is required so every rank samples the same data.
    if(args.seed is None and
        ((args.gpus is not None and args.gpus > 1) or
         (args.num_nodes is not None and args.num_nodes > 1))):
        raise ValueError("For distributed training, --seed must be specified")

    # This re-applies the training-time filters at the beginning of every epoch
    args.reload_dataloaders_every_n_epochs = 1

    main(args)