train_openfold.py 7.89 KB
Newer Older
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
1
2
3
4
import argparse
import logging
import os

5
#os.environ["CUDA_VISIBLE_DEVICES"] = "5"
6
7
8
9
#os.environ["MASTER_ADDR"]="10.119.81.14"
#os.environ["MASTER_PORT"]="42069"
#os.environ["NODE_RANK"]="0"

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
10

11
import random
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
12
13
import time

14
import numpy as np
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
15
import pytorch_lightning as pl
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
16
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
17
18
19
20
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin
import torch

from openfold.config import model_config
21
22
from openfold.data.data_modules import (
    OpenFoldDataModule,
23
    DummyDataLoader,
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
24
)
25
from openfold.model.model import AlphaFold
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
26
27
28
from openfold.utils.callbacks import (
    EarlyStoppingVerbose,
)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
29
30
from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.loss import AlphaFoldLoss
31
from openfold.utils.seed import seed_everything
32
from openfold.utils.tensor_utils import tensor_tree_map
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
33
34
35
36
37
38


class OpenFoldWrapper(pl.LightningModule):
    """LightningModule wrapping an AlphaFold model, its loss, and an EMA copy.

    During validation the EMA weights are swapped into the model. A detached
    snapshot of the live training weights is kept in ``self.cached_weights``
    and restored when the validation epoch ends.
    """

    def __init__(self, config):
        super(OpenFoldWrapper, self).__init__()
        self.config = config
        self.model = AlphaFold(config)
        self.loss = AlphaFoldLoss(config.loss)
        self.ema = ExponentialMovingAverage(
            model=self.model, decay=config.ema.decay
        )
        # Snapshot of the live (non-EMA) weights, held only while validation
        # runs with the EMA weights loaded; None at all other times.
        self.cached_weights = None

    def forward(self, batch):
        return self.model(batch)

    def training_step(self, batch, batch_idx):
        """Run one training step and return the AlphaFold loss."""
        # Keep the EMA state on the same device as the incoming batch
        if(self.ema.device != batch["aatype"].device):
            self.ema.to(batch["aatype"].device)

        # Run the model
        outputs = self(batch)

        # Remove the recycling dimension
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

        # Compute loss
        loss = self.loss(outputs, batch)

        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss using the EMA weights.

        On the first step of a validation epoch, the live weights are
        snapshotted and replaced by the EMA parameters.
        """
        # At the start of validation, load the EMA weights
        if(self.cached_weights is None):
            # state_dict() returns references to the model's own tensors and
            # load_state_dict() copies into them in place, so the snapshot
            # must be an independent copy. Otherwise loading the EMA weights
            # would overwrite it.
            self.cached_weights = {
                k: v.detach().clone()
                for k, v in self.model.state_dict().items()
            }
            self.model.load_state_dict(self.ema.state_dict()["params"])

        # Calculate validation loss
        outputs = self(batch)
        batch = tensor_tree_map(lambda t: t[..., -1], batch)
        loss = self.loss(outputs, batch)

        # Log the metric so callbacks monitoring "val_loss" (ModelCheckpoint,
        # EarlyStopping) can see it; returning it alone does not log it.
        self.log("val_loss", loss)
        return {"val_loss": loss}

    def validation_epoch_end(self, _):
        """Restore the live training weights after validation."""
        # Nothing to restore if no validation step ran this epoch
        if(self.cached_weights is not None):
            self.model.load_state_dict(self.cached_weights)
            self.cached_weights = None

    def configure_optimizers(self, 
        learning_rate: float = 1e-3,
        eps: float = 1e-8
    ) -> torch.optim.Adam:
        """Return an Adam optimizer over the wrapped model's parameters."""
        # Ignored as long as a DeepSpeed optimizer is configured
        return torch.optim.Adam(
            self.model.parameters(), 
            lr=learning_rate, 
            eps=eps
        )

    def on_before_zero_grad(self, *args, **kwargs):
        # Lightning calls this hook after the optimizer step and before the
        # gradients are zeroed, so the EMA tracks the freshly-updated weights.
        self.ema.update(self.model)

    def on_save_checkpoint(self, checkpoint):
        # Persist the EMA state alongside the regular Lightning checkpoint
        checkpoint["ema"] = self.ema.state_dict()


def main(args):
    """Build the model, data module and Trainer from ``args``, then train.

    Args:
        args: Parsed command-line namespace. It includes this script's own
            options and the options added by ``pl.Trainer.add_argparse_args``.
    """
    if(args.seed is not None):
        seed_everything(args.seed) 

    config = model_config(
        "model_1", 
        train=True, 
        low_prec=(args.precision == 16)
    ) 

    model_module = OpenFoldWrapper(config) 
    #data_module = DummyDataLoader("batch.pickle")
    data_module = OpenFoldDataModule(
        config=config.data, 
        batch_seed=args.seed,
        **vars(args)
    )
    data_module.prepare_data()
    data_module.setup()

    callbacks = []
    if(args.checkpoint_best_val):
        checkpoint_dir = os.path.join(args.output_dir, "checkpoints")
        mc = ModelCheckpoint(
            dirpath=checkpoint_dir,
            filename="openfold_{epoch}_{step}_{val_loss:.2f}",
            monitor="val_loss",
        )
        callbacks.append(mc)

    if(args.early_stopping):
        es = EarlyStoppingVerbose(
            monitor="val_loss",
            min_delta=args.min_delta,
            patience=args.patience,
            verbose=False,
            mode="min",
            check_finite=True,
            strict=True,
        )
        callbacks.append(es)

    plugins = []
    if(args.deepspeed_config_path is not None):
        plugins.append(DeepSpeedPlugin(config=args.deepspeed_config_path))

    # The callbacks must be handed to the Trainer, or the checkpointing and
    # early-stopping options above have no effect.
    trainer = pl.Trainer.from_argparse_args(
        args,
        plugins=plugins,
        callbacks=callbacks,
    )

    trainer.fit(model_module, datamodule=data_module)

    # NOTE(review): this path is relative to the current working directory,
    # not args.output_dir. Confirm whether that is intended.
    trainer.save_checkpoint("final.ckpt")
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166


def str2bool(value):
    """Parse a boolean command-line value.

    With ``type=bool``, argparse calls ``bool()`` on the raw string. That is
    True for every non-empty string, so ``--flag False`` would enable the
    flag. This parser accepts common true/false spellings instead.

    Args:
        value: The raw argument string, or an already-parsed bool.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognized spelling.
    """
    if(isinstance(value, bool)):
        return value
    normalized = value.strip().lower()
    if(normalized in ("true", "t", "yes", "y", "1", "on")):
        return True
    if(normalized in ("false", "f", "no", "n", "0", "off")):
        return False
    raise argparse.ArgumentTypeError(
        "Expected a boolean value, got {!r}".format(value)
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "train_data_dir", type=str,
        help="Directory containing training mmCIF files"
    )
    parser.add_argument(
        "train_alignment_dir", type=str,
        help="Directory containing precomputed training alignments"
    )
    parser.add_argument(
        "template_mmcif_dir", type=str,
        help="Directory containing mmCIF files to search for templates"
    )
    parser.add_argument(
        "output_dir", type=str,
        help='''Directory in which to output checkpoints, logs, etc. Ignored
                if not on rank 0'''
    )
    parser.add_argument(
        "max_template_date", type=str,
        help='''Cutoff for all templates. In training mode, templates are also 
                filtered by the release date of the target'''
    )
    parser.add_argument(
        "--distillation_data_dir", type=str, default=None,
        help="Directory containing training PDB files"
    )
    parser.add_argument(
        "--distillation_alignment_dir", type=str, default=None,
        help="Directory containing precomputed distillation alignments"
    )
    parser.add_argument(
        "--val_data_dir", type=str, default=None,
        help="Directory containing validation mmCIF files"
    )
    parser.add_argument(
        "--val_alignment_dir", type=str, default=None,
        help="Directory containing precomputed validation alignments"
    )
    parser.add_argument(
        "--kalign_binary_path", type=str, default='/usr/bin/kalign',
        help="Path to the kalign binary"
    )
    parser.add_argument(
        "--train_mapping_path", type=str, default=None,
        help='''Optional path to a .json file containing a mapping from
                consecutive numerical indices to sample names. Used to filter
                the training set'''
    )
    parser.add_argument(
        "--distillation_mapping_path", type=str, default=None,
        help="""See --train_mapping_path"""
    )
    parser.add_argument(
        "--template_release_dates_cache_path", type=str, default=None,
        help="""Output of scripts/generate_mmcif_cache.py run on template mmCIF
                files."""
    )
    # Boolean options use str2bool: type=bool would treat "False" as True
    parser.add_argument(
        "--use_small_bfd", type=str2bool, default=False,
        help="Whether to use a reduced version of the BFD database"
    )
    parser.add_argument(
        "--seed", type=int, default=None,
        help="Random seed"
    )
    parser.add_argument(
        "--deepspeed_config_path", type=str, default=None,
        help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
    )
    parser.add_argument(
        "--checkpoint_best_val", type=str2bool, default=True,
        help="""Whether to save the model parameters that perform best during
                validation"""
    )
    parser.add_argument(
        "--early_stopping", type=str2bool, default=False,
        help="Whether to stop training when validation loss fails to decrease"
    )
    parser.add_argument(
        "--min_delta", type=float, default=0,
        help="""The smallest decrease in validation loss that counts as an 
                improvement for the purposes of early stopping"""
    )
    parser.add_argument(
        "--patience", type=int, default=3,
        help="Early stopping patience"
    )
    parser = pl.Trainer.add_argparse_args(parser)

    parser.set_defaults(
        num_sanity_val_steps=0,
    )

    args = parser.parse_args()

    # Depending on the Lightning version, --gpus may arrive as a str holding
    # a comma-separated device list (e.g. "0,1") rather than an int count.
    # Normalize it to a count so the comparison below cannot raise TypeError.
    num_gpus = args.gpus
    if(isinstance(num_gpus, str)):
        num_gpus = len([d for d in num_gpus.split(",") if d.strip()])

    if(args.seed is None and 
        ((num_gpus is not None and num_gpus > 1) or 
         (args.num_nodes is not None and args.num_nodes > 1))):
        raise ValueError("For distributed training, --seed must be specified")

    main(args)