Commit fcf9e2f8 authored by Gustaf Ahdritz

Add dummy dataloader, most recent version of training script

parent 6dc8aa7f
@@ -13,7 +13,8 @@ cases where the *Nature* paper differs from the source, we always defer to the
latter.
OpenFold is built to support inference with AlphaFold's original JAX weights.
-Try it out with our [Colab notebook](https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb).
+Try it out with our [Colab notebook](https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb)
+(not yet visible from Colab because the repo is still private).
Unlike DeepMind's public code, OpenFold is also trainable. It can be trained
with or without [DeepSpeed](https://github.com/microsoft/deepspeed) and with
@@ -2,6 +2,7 @@ from functools import partial
import json
import logging
import os
+import pickle
from typing import Optional, Sequence
import ml_collections as mlc
@@ -446,3 +447,24 @@ class OpenFoldDataModule(pl.LightningDataModule):
            num_workers=self.config.data_module.data_loaders.num_workers,
            collate_fn=self._gen_batch_collator("predict")
        )
class DummyDataset(torch.utils.data.Dataset):
    """Serves deep copies of a single pre-pickled feature batch."""
    def __init__(self, batch_path):
        with open(batch_path, "rb") as f:
            # Stash the batch on self so __getitem__ can reach it
            self.batch = pickle.load(f)

    def __getitem__(self, idx):
        # Deep-copy so in-place ops downstream can't corrupt the template
        return copy.deepcopy(self.batch)

    def __len__(self):
        return 1000


class DummyDataLoader(pl.LightningDataModule):
    def __init__(self, batch_path):
        super().__init__()
        self.dataset = DummyDataset(batch_path)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.dataset)
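
The commit message calls this a dummy dataloader: `DummyDataset` replays deep copies of one cached batch, which isolates model throughput from the real data pipeline. Note that `copy.deepcopy` relies on `import copy` at the top of `data_modules.py`, outside the visible hunk. The commit does not show how `batch.pickle` is produced; one plausible recipe, sketched below under the assumption of an already-configured `OpenFoldDataModule`, is to pickle a single batch from the real pipeline:

import pickle

# Assumption: `data_module` is a fully-configured OpenFoldDataModule
data_module.prepare_data()
data_module.setup()
batch = next(iter(data_module.train_dataloader()))

# Cache the batch to disk for later replay
with open("batch.pickle", "wb") as f:
    pickle.dump(batch, f)

# Replay it without touching the real pipeline, matching the commented-out
# call in the training script below
dummy_module = DummyDataLoader("batch.pickle")

Returning a fixed length of 1000 simply gives the trainer enough "samples" to run repeated steps on the cached batch for a full epoch.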
@@ -2,25 +2,33 @@ import argparse
import logging
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "5"
os.environ["CUDA_VISIBLE_DEVICES"] = "6"
#os.environ["MASTER_ADDR"]="10.119.81.14"
#os.environ["MASTER_PORT"]="42069"
#os.environ["NODE_RANK"]="0"
import random
import time
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin
import torch
from openfold.config import model_config
from openfold.data.data_modules import (
    OpenFoldDataModule,
    DummyDataLoader,
)
from openfold.model.model import AlphaFold
from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.loss import AlphaFoldLoss
from openfold.utils.tensor_utils import tensor_tree_map
import copy
class OpenFoldWrapper(pl.LightningModule):
    def __init__(self, config):
@@ -28,12 +36,17 @@ class OpenFoldWrapper(pl.LightningModule):
        self.config = config
        self.model = AlphaFold(config)
        self.loss = AlphaFoldLoss(config.loss)
-        self.ema = ExponentialMovingAverage(self.model, decay=config.ema.decay)
+        self.ema = ExponentialMovingAverage(
+            model=self.model, decay=config.ema.decay
+        )

    def forward(self, batch):
        return self.model(batch)

    def training_step(self, batch, batch_idx):
        if(self.ema.device != batch["aatype"].device):
            self.ema.to(batch["aatype"].device)

        # Run the model
        outputs = self(batch)
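
`ExponentialMovingAverage` lives in `openfold.utils.exponential_moving_average` and is not part of this diff. For readers unfamiliar with the technique, here is a minimal sketch of weight EMA; the class name, `update` method, and float-only filter are illustrative assumptions, not OpenFold's implementation:

import torch

class SimpleEMA:
    """Minimal sketch: shadow <- decay * shadow + (1 - decay) * param."""

    def __init__(self, model: torch.nn.Module, decay: float = 0.999):
        self.decay = decay
        # Clone the current weights as the initial running average
        self.shadow = {
            k: v.detach().clone()
            for k, v in model.state_dict().items()
            if v.dtype.is_floating_point
        }

    @torch.no_grad()
    def update(self, model: torch.nn.Module):
        # Blend live weights into the shadow copy after each optimizer step
        for k, v in model.state_dict().items():
            if k in self.shadow:
                self.shadow[k].mul_(self.decay).add_(v, alpha=1.0 - self.decay)

At evaluation time the shadow weights are swapped in for the raw ones, smoothing out step-to-step optimization noise; the device check in `training_step` above exists because the shadow copy must live on the same device as the incoming batch.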
@@ -82,20 +95,31 @@ def main(args):
"model_1",
train=True,
low_prec=(args.precision == 16)
)
)
model_module = OpenFoldWrapper(config)
#data_module = DummyDataLoader("batch.pickle")
data_module = OpenFoldDataModule(
config=config.data,
batch_seed=args.seed,
**vars(args)
)
data_module.prepare_data()
data_module.setup()
plugins = []
#plugins.append(DeepSpeedPlugin(config="deepspeed_config.json"))
if(args.deepspeed_config_path is not None):
plugins.append(DeepSpeedPlugin(config=args.deepspeed_config_path))
#os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
#plugins.append(DDPPlugin(find_unused_parameters=True))
trainer = pl.Trainer.from_argparse_args(
args,
plugins=plugins,
)
model_module = OpenFoldWrapper(config)
data_module = OpenFoldDataModule(config=config.data, **vars(args))
trainer.fit(model_module, data_module)
trainer.fit(model_module, datamodule=data_module)
if __name__ == "__main__":
@@ -160,6 +184,10 @@ if __name__ == "__main__":
"--seed", type=int, default=None,
help="Random seed"
)
parser.add_argument(
"--deepspeed_config_path", type=str, default=None,
help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
)
parser = pl.Trainer.add_argparse_args(parser)
parser.set_defaults(
@@ -172,5 +200,6 @@ if __name__ == "__main__":
        torch.manual_seed(args.seed)
        random.seed(args.seed + 1)
        np.random.seed(args.seed + 2)
+        args.seed += 1

    main(args)
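
The new `--deepspeed_config_path` flag expects a standard DeepSpeed JSON config file. As an illustration only, the snippet below writes a minimal config; the keys are standard DeepSpeed options, but the values are placeholders rather than settings vetted for OpenFold:

import json

ds_config = {
    # One sample per GPU per step; raise once memory headroom is known
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 1,
    # ZeRO stage 2 shards optimizer state and gradients across ranks
    "zero_optimization": {"stage": 2},
    "fp16": {"enabled": True},
}

with open("deepspeed_config.json", "w") as f:
    json.dump(ds_config, f, indent=2)

Training would then be launched with `--deepspeed_config_path deepspeed_config.json`; when the flag is omitted, `plugins` stays empty and the script falls back to plain (non-DeepSpeed) training.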