Commit fcf9e2f8 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add dummy dataloader, most recent version of training script

parent 6dc8aa7f
...@@ -13,7 +13,8 @@ cases where the *Nature* paper differs from the source, we always defer to the ...@@ -13,7 +13,8 @@ cases where the *Nature* paper differs from the source, we always defer to the
latter. latter.
OpenFold is built to support inference with AlphaFold's original JAX weights. OpenFold is built to support inference with AlphaFold's original JAX weights.
Try it out with our [Colab notebook](https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb). Try it out with our [Colab notebook](https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb)
(not yet visible from Colab because the repo is still private).
Unlike DeepMind's public code, OpenFold is also trainable. It can be trained Unlike DeepMind's public code, OpenFold is also trainable. It can be trained
with or without [DeepSpeed](https://github.com/microsoft/deepspeed) and with with or without [DeepSpeed](https://github.com/microsoft/deepspeed) and with
......
...@@ -2,6 +2,7 @@ from functools import partial ...@@ -2,6 +2,7 @@ from functools import partial
import json import json
import logging import logging
import os import os
import copy
import pickle
from typing import Optional, Sequence from typing import Optional, Sequence
import ml_collections as mlc import ml_collections as mlc
...@@ -446,3 +447,24 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -446,3 +447,24 @@ class OpenFoldDataModule(pl.LightningDataModule):
num_workers=self.config.data_module.data_loaders.num_workers, num_workers=self.config.data_module.data_loaders.num_workers,
collate_fn=self._gen_batch_collator("predict") collate_fn=self._gen_batch_collator("predict")
) )
class DummyDataset(torch.utils.data.Dataset):
    """Debug dataset that serves one pre-pickled batch over and over.

    Useful for benchmarking the training loop without paying the cost of
    real data loading/featurization.

    Args:
        batch_path: Path to a pickle file containing a single model batch.
    """
    def __init__(self, batch_path):
        with open(batch_path, "rb") as f:
            # NOTE: pickle.load executes arbitrary code on malicious input;
            # batch_path is expected to be a locally generated debug file.
            # Bug fix: the loaded batch must be stored on self, otherwise
            # __getitem__ raises AttributeError.
            self.batch = pickle.load(f)

    def __getitem__(self, idx):
        # Return a deep copy so in-place ops downstream (e.g. moving tensors
        # to device, masking) cannot mutate the cached batch.
        return copy.deepcopy(self.batch)

    def __len__(self):
        # Arbitrary fixed epoch length for the dummy data.
        return 1000
class DummyDataLoader(pl.LightningDataModule):
    """Minimal LightningDataModule wrapping DummyDataset (for benchmarking).

    Args:
        batch_path: Path to a pickled batch, forwarded to DummyDataset.
            The training script constructs this as
            DummyDataLoader("batch.pickle").
    """
    def __init__(self, batch_path):
        super().__init__()
        # Bug fix: the original instantiated the undefined name "Dataset()"
        # and ignored the batch path the caller passes.
        self.dataset = DummyDataset(batch_path)

    def train_dataloader(self):
        # Default DataLoader settings (batch_size=1, no workers) are fine
        # for a dummy data path.
        return torch.utils.data.DataLoader(self.dataset)
...@@ -2,25 +2,33 @@ import argparse ...@@ -2,25 +2,33 @@ import argparse
import logging import logging
import os import os
os.environ["CUDA_VISIBLE_DEVICES"] = "5" os.environ["CUDA_VISIBLE_DEVICES"] = "6"
#os.environ["MASTER_ADDR"]="10.119.81.14"
#os.environ["MASTER_PORT"]="42069"
#os.environ["NODE_RANK"]="0"
import random import random
import time import time
import numpy as np import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin from pytorch_lightning.plugins.training_type import DeepSpeedPlugin
import torch import torch
from openfold.config import model_config from openfold.config import model_config
from openfold.data.data_modules import ( from openfold.data.data_modules import (
OpenFoldDataModule, OpenFoldDataModule,
DummyDataLoader,
) )
from openfold.model.model import AlphaFold from openfold.model.model import AlphaFold
from openfold.utils.exponential_moving_average import ExponentialMovingAverage from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.loss import AlphaFoldLoss from openfold.utils.loss import AlphaFoldLoss
from openfold.utils.tensor_utils import tensor_tree_map from openfold.utils.tensor_utils import tensor_tree_map
import copy
class OpenFoldWrapper(pl.LightningModule): class OpenFoldWrapper(pl.LightningModule):
def __init__(self, config): def __init__(self, config):
...@@ -28,12 +36,17 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -28,12 +36,17 @@ class OpenFoldWrapper(pl.LightningModule):
self.config = config self.config = config
self.model = AlphaFold(config) self.model = AlphaFold(config)
self.loss = AlphaFoldLoss(config.loss) self.loss = AlphaFoldLoss(config.loss)
self.ema = ExponentialMovingAverage(self.model, decay=config.ema.decay) self.ema = ExponentialMovingAverage(
model=self.model, decay=config.ema.decay
)
def forward(self, batch): def forward(self, batch):
return self.model(batch) return self.model(batch)
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
if(self.ema.device != batch["aatype"].device):
self.ema.to(batch["aatype"].device)
# Run the model # Run the model
outputs = self(batch) outputs = self(batch)
...@@ -84,18 +97,29 @@ def main(args): ...@@ -84,18 +97,29 @@ def main(args):
low_prec=(args.precision == 16) low_prec=(args.precision == 16)
) )
model_module = OpenFoldWrapper(config)
#data_module = DummyDataLoader("batch.pickle")
data_module = OpenFoldDataModule(
config=config.data,
batch_seed=args.seed,
**vars(args)
)
data_module.prepare_data()
data_module.setup()
plugins = [] plugins = []
#plugins.append(DeepSpeedPlugin(config="deepspeed_config.json")) if(args.deepspeed_config_path is not None):
plugins.append(DeepSpeedPlugin(config=args.deepspeed_config_path))
#os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
#plugins.append(DDPPlugin(find_unused_parameters=True))
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer.from_argparse_args(
args, args,
plugins=plugins, plugins=plugins,
) )
model_module = OpenFoldWrapper(config) trainer.fit(model_module, datamodule=data_module)
data_module = OpenFoldDataModule(config=config.data, **vars(args))
trainer.fit(model_module, data_module)
if __name__ == "__main__": if __name__ == "__main__":
...@@ -160,6 +184,10 @@ if __name__ == "__main__": ...@@ -160,6 +184,10 @@ if __name__ == "__main__":
"--seed", type=int, default=None, "--seed", type=int, default=None,
help="Random seed" help="Random seed"
) )
parser.add_argument(
"--deepspeed_config_path", type=str, default=None,
help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
)
parser = pl.Trainer.add_argparse_args(parser) parser = pl.Trainer.add_argparse_args(parser)
parser.set_defaults( parser.set_defaults(
...@@ -172,5 +200,6 @@ if __name__ == "__main__": ...@@ -172,5 +200,6 @@ if __name__ == "__main__":
torch.manual_seed(args.seed) torch.manual_seed(args.seed)
random.seed(args.seed + 1) random.seed(args.seed + 1)
np.random.seed(args.seed + 2) np.random.seed(args.seed + 2)
args.seed += 1
main(args) main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment