"googlemock/include/vscode:/vscode.git/clone" did not exist on "e7ed50fd137dd7626bbb21dfc41982454dcff69b"
train_openfold.py 12.9 KB
Newer Older
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
1
2
3
4
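"""Training entry point for OpenFold.

Wraps AlphaFold in a PyTorch Lightning module that keeps an exponential
moving average (EMA) of the model weights, wires up the OpenFold data
module, and runs training, optionally under DeepSpeed or DDP.
"""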
import argparse
import logging
import os

#os.environ["CUDA_VISIBLE_DEVICES"] = "0"
#os.environ["MASTER_ADDR"]="10.119.81.14"
#os.environ["MASTER_PORT"]="42069"
#os.environ["NODE_RANK"]="0"

import random
import time

import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
from pytorch_lightning.plugins.environments import SLURMEnvironment
import torch

from openfold.config import model_config
from openfold.data.data_modules import (
    OpenFoldDataModule,
    DummyDataLoader,
)
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.utils.callbacks import (
    EarlyStoppingVerbose,
)
from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.argparse import remove_arguments
from openfold.utils.loss import AlphaFoldLoss, lddt_ca
from openfold.utils.seed import seed_everything
from openfold.utils.tensor_utils import tensor_tree_map
from scripts.zero_to_fp32 import (
    get_fp32_state_dict_from_zero_checkpoint
)

from openfold.utils.logger import PerformanceLoggingCallback


class OpenFoldWrapper(pl.LightningModule):
    def __init__(self, config):
        super(OpenFoldWrapper, self).__init__()
        self.config = config
        self.model = AlphaFold(config)
        self.loss = AlphaFoldLoss(config.loss)
        self.ema = ExponentialMovingAverage(
            model=self.model, decay=config.ema.decay
        )
        # Holds a backup of the online weights while the EMA weights are
        # swapped in for validation (see validation_step/validation_epoch_end)
        self.cached_weights = None

    def forward(self, batch):
        return self.model(batch)

    def training_step(self, batch, batch_idx):
        # Keep the EMA parameters on the same device as the inputs
        if(self.ema.device != batch["aatype"].device):
            self.ema.to(batch["aatype"].device)

        # Run the model
        outputs = self(batch)
        
        # Remove the recycling dimension
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

        # Compute loss
        loss = self.loss(outputs, batch)

        self.log("train/loss", loss, logger=True)

        return loss

    def validation_step(self, batch, batch_idx):
        # At the start of validation, load the EMA weights
        if(self.cached_weights is None):
            self.cached_weights = self.model.state_dict()
            self.model.load_state_dict(self.ema.state_dict()["params"])
        # Calculate validation loss
        outputs = self(batch)
        batch = tensor_tree_map(lambda t: t[..., -1], batch)
        # The validation "loss" logged here is actually the Cα lDDT score
        # of the final predicted structure
        loss = lddt_ca(
            outputs["final_atom_positions"],
            batch["all_atom_positions"],
            batch["all_atom_mask"],
            eps=self.config.globals.eps,
            per_residue=False,
        )
        self.log("val/loss", loss, logger=True)

    def validation_epoch_end(self, _):
        # Restore the model weights to normal
        self.model.load_state_dict(self.cached_weights)
        self.cached_weights = None

    def configure_optimizers(self,
        learning_rate: float = 1e-3,
        eps: float = 1e-5,
    ) -> torch.optim.Adam:
        # Ignored as long as a DeepSpeed optimizer is configured
        return torch.optim.Adam(
            self.model.parameters(),
            lr=learning_rate,
            eps=eps
        )

    def on_before_zero_grad(self, *args, **kwargs):
        # Runs after each optimizer step; refresh the EMA copy of the weights
        self.ema.update(self.model)

    def on_load_checkpoint(self, checkpoint):
        self.ema.load_state_dict(checkpoint["ema"])

    def on_save_checkpoint(self, checkpoint):
        checkpoint["ema"] = self.ema.state_dict()


def main(args):
    if(args.seed is not None):
        seed_everything(args.seed) 

    config = model_config(
        "model_1", 
        train=True, 
        low_prec=(args.precision == "16")
    ) 

    model_module = OpenFoldWrapper(config)
    if(args.resume_from_ckpt and args.resume_model_weights_only):
        # Assemble full-precision weights from the sharded DeepSpeed ZeRO
        # checkpoint, then strip the "module." prefix added by the wrapper
        sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt)
        sd = {k[len("module."):]: v for k, v in sd.items()}
        model_module.load_state_dict(sd)
        logging.info("Successfully loaded model weights...")

    # TorchScript components of the model
    if(args.script_modules):
        script_preset_(model_module)

    #data_module = DummyDataLoader("batch.pickle")
    data_module = OpenFoldDataModule(
        config=config.data, 
        batch_seed=args.seed,
        **vars(args)
    )

    data_module.prepare_data()
    data_module.setup()

    callbacks = []
    if(args.checkpoint_best_val):
        checkpoint_dir = os.path.join(args.output_dir, "checkpoints")
        mc = ModelCheckpoint(
            dirpath=checkpoint_dir,
            filename="openfold_{epoch}_{step}_{val_loss:.2f}",
            monitor="val/loss",
        )
        callbacks.append(mc)

    if(args.early_stopping):
        es = EarlyStoppingVerbose(
            monitor="val/loss",
            min_delta=args.min_delta,
            patience=args.patience,
            verbose=False,
            mode="min",
            check_finite=True,
            strict=True,
        )
        callbacks.append(es)

    if(args.log_performance):
        global_batch_size = args.num_nodes * args.gpus
        perf = PerformanceLoggingCallback(
            log_file=os.path.join(args.output_dir, "performance_log.json"),
            global_batch_size=global_batch_size,
        )
        callbacks.append(perf)

    loggers = []
    if(args.wandb):
        wdb_logger = WandbLogger(
            name=args.experiment_name,
            save_dir=args.output_dir,
            id=args.wandb_id,
            project=args.wandb_project,
            entity=args.wandb_entity,
        )
        loggers.append(wdb_logger)

    if(args.deepspeed_config_path is not None):
        #if "SLURM_JOB_ID" in os.environ:
        #    cluster_environment = SLURMEnvironment()
        #else:
        #    cluster_environment = None
        strategy = DeepSpeedPlugin(
            config=args.deepspeed_config_path,
        #    cluster_environment=cluster_environment,
        )
        if(args.wandb):
            wdb_logger.experiment.save(args.deepspeed_config_path)
    elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1:
        strategy = DDPPlugin(find_unused_parameters=False)
    else:
        strategy = None

    trainer = pl.Trainer.from_argparse_args(
        args,
        default_root_dir=args.output_dir,
        strategy=strategy,
        callbacks=callbacks,
        logger=loggers,
    )

    if(args.resume_model_weights_only):
        # The weights were already loaded above; don't restore trainer state
        ckpt_path = None
    else:
        ckpt_path = args.resume_from_ckpt

    trainer.fit(
        model_module, 
        datamodule=data_module,
        ckpt_path=ckpt_path,
    )

    trainer.save_checkpoint(
        os.path.join(args.output_dir, "checkpoints", "final.ckpt")
    )


def bool_type(bool_str: str):
    bool_str_lower = bool_str.lower()
    if bool_str_lower in ('false', 'f', 'no', 'n', '0'):
        return False
    elif bool_str_lower in ('true', 't', 'yes', 'y', '1'):
        return True
    else:
        raise ValueError(f'Cannot interpret {bool_str} as bool')
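
# Illustrative examples of the helper above (not part of the original
# script):
#
#   >>> bool_type("yes")
#   True
#   >>> bool_type("0")
#   False
#   >>> bool_type("maybe")  # raises ValueError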


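# Example invocation (a sketch; every path and flag value below is an
# illustrative placeholder, not a shipped default):
#
#   python train_openfold.py \
#       /data/mmcif/train /data/alignments/train /data/mmcif/templates \
#       /runs/openfold 2021-10-10 \
#       --val_data_dir /data/mmcif/val \
#       --val_alignment_dir /data/alignments/val \
#       --gpus 8 --num_nodes 1 --seed 42 \
#       --precision 16 \
#       --deepspeed_config_path deepspeed_config.json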
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "train_data_dir", type=str,
        help="Directory containing training mmCIF files"
    )
    parser.add_argument(
        "train_alignment_dir", type=str,
        help="Directory containing precomputed training alignments"
    )
    parser.add_argument(
        "template_mmcif_dir", type=str,
        help="Directory containing mmCIF files to search for templates"
    )
    parser.add_argument(
        "output_dir", type=str,
        help='''Directory in which to output checkpoints, logs, etc. Ignored
                if not on rank 0'''
    )
    parser.add_argument(
        "max_template_date", type=str,
        help='''Cutoff for all templates. In training mode, templates are also 
                filtered by the release date of the target'''
    )
    parser.add_argument(
        "--distillation_data_dir", type=str, default=None,
        help="Directory containing training PDB files"
    )
    parser.add_argument(
        "--distillation_alignment_dir", type=str, default=None,
        help="Directory containing precomputed distillation alignments"
    )
    parser.add_argument(
        "--val_data_dir", type=str, default=None,
        help="Directory containing validation mmCIF files"
    )
    parser.add_argument(
        "--val_alignment_dir", type=str, default=None,
        help="Directory containing precomputed validation alignments"
    )
    parser.add_argument(
        "--kalign_binary_path", type=str, default='/usr/bin/kalign',
        help="Path to the kalign binary"
    )
    parser.add_argument(
        "--train_mapping_path", type=str, default=None,
        help='''Optional path to a .json file containing a mapping from
                consecutive numerical indices to sample names. Used to filter
                the training set'''
    )
    parser.add_argument(
        "--distillation_mapping_path", type=str, default=None,
        help="""See --train_mapping_path"""
    )
    parser.add_argument(
        "--obsolete_pdbs_file_path", type=str, default=None,
        help="""Path to obsolete.dat file containing list of obsolete PDBs and 
             their replacements."""
    )
    parser.add_argument(
        "--template_release_dates_cache_path", type=str, default=None,
        help="""Output of scripts/generate_mmcif_cache.py run on template mmCIF
                files."""
    )
    parser.add_argument(
        "--use_small_bfd", type=bool_type, default=False,
        help="Whether to use a reduced version of the BFD database"
    )
    parser.add_argument(
        "--seed", type=int, default=None,
        help="Random seed"
    )
    parser.add_argument(
        "--deepspeed_config_path", type=str, default=None,
        help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
    )
    parser.add_argument(
        "--checkpoint_best_val", type=bool_type, default=True,
        help="""Whether to save the model parameters that perform best during
                validation"""
    )
    parser.add_argument(
        "--early_stopping", type=bool_type, default=False,
        help="Whether to stop training when validation loss fails to decrease"
    )
    parser.add_argument(
        "--min_delta", type=float, default=0,
        help="""The smallest decrease in validation loss that counts as an 
                improvement for the purposes of early stopping"""
    )
    parser.add_argument(
        "--patience", type=int, default=3,
        help="Early stopping patience"
    )
    parser.add_argument(
        "--resume_from_ckpt", type=str, default=None,
        help="Path to a model checkpoint from which to restore training state"
    )
    parser.add_argument(
        "--resume_model_weights_only", type=bool_type, default=False,
        help="Whether to load just model weights as opposed to training state"
    )
    parser.add_argument(
        "--log_performance", type=bool_type, default=False,
        help="Measure performance"
    )
    parser.add_argument(
        "--wandb", action="store_true", default=False,
    )
    parser.add_argument(
        "--experiment_name", type=str, default=None,
    )
    parser.add_argument(
        "--wandb_id", type=str, default=None,
    )
    parser.add_argument(
        "--wandb_project", type=str, default=None,
    )
    parser.add_argument(
        "--wandb_entity", type=str, default=None,
    )
    parser.add_argument(
        "--script_modules", type=bool_type, default=False,
        help="Whether to TorchScript eligible components of the model"
    )
    parser.add_argument(
        "--train_prot_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--distillation_prot_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--train_epoch_len", type=int, default=10000,
    )
    parser.add_argument(
        "--_alignment_index_path", type=str, default=None,
    )
    parser = pl.Trainer.add_argparse_args(parser)
    # Disable the initial validation pass
    parser.set_defaults(
        num_sanity_val_steps=0,
    )

    # Remove some buggy/redundant arguments introduced by the Trainer
    remove_arguments(
        parser, 
        [
            "--accelerator", 
            "--resume_from_checkpoint",
            "--reload_dataloaders_every_epoch",
            "--reload_dataloaders_every_n_epochs",
        ]
    ) 

    args = parser.parse_args()

    if(args.seed is None and 
        ((args.gpus is not None and args.gpus > 1) or 
         (args.num_nodes is not None and args.num_nodes > 1))):
        raise ValueError("For distributed training, --seed must be specified")

    # This re-applies the training-time filters at the beginning of every epoch
    args.reload_dataloaders_every_n_epochs = 1

    main(args)