Commit 4b354151 authored by Geoffrey Yu

update train_openfold.py to accommodate multimer training

parent d886a7be
@@ -16,18 +16,18 @@ import torch
 from openfold.config import model_config
 from openfold.data.data_modules import (
-    OpenFoldDataModule,
+    OpenFoldDataModule, OpenFoldMultimerDataModule,
     DummyDataLoader,
 )
 from openfold.model.model import AlphaFold
 from openfold.model.torchscript import script_preset_
 from openfold.np import residue_constants
-from openfold.utils.argparse import remove_arguments
+from openfold.utils.argparse_utils import remove_arguments
 from openfold.utils.callbacks import (
     EarlyStoppingVerbose,
 )
 from openfold.utils.exponential_moving_average import ExponentialMovingAverage
-from openfold.utils.loss import AlphaFoldLoss, lddt_ca
+from openfold.utils.loss import AlphaFoldLoss, AlphaFoldMultimerLoss, lddt_ca
 from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
 from openfold.utils.seed import seed_everything
 from openfold.utils.superimposition import superimpose
@@ -257,6 +257,69 @@ class OpenFoldWrapper(pl.LightningModule):
         )
+
+class OpenFoldMultimerWrapper(OpenFoldWrapper):
+    def __init__(self, config):
+        super(OpenFoldMultimerWrapper, self).__init__(config)
+        self.config = config
+        # The multimer masked-MSA head predicts 22 classes, so the monomer
+        # default in the loss config needs to be overridden here
+        self.config.loss.masked_msa.num_classes = 22
+        # Use a shallow Evoformer stack; no need to go overboard here
+        self.config.model.evoformer_stack.no_blocks = 4
+        # Don't want to set up per-block activation checkpointing
+        self.config.model.evoformer_stack.blocks_per_ckpt = None
+        self.model = AlphaFold(config)
+        self.loss = AlphaFoldMultimerLoss(config.loss)
+        self.ema = ExponentialMovingAverage(
+            model=self.model, decay=config.ema.decay
+        )
+        self.cached_weights = None
+        self.last_lr_step = -1
+
+    def forward(self, batch):
+        return self.model(batch)
+
+    def training_step(self, batch, batch_idx):
+        all_chain_features, ground_truth = batch
+        if(self.ema.device != all_chain_features["aatype"].device):
+            self.ema.to(all_chain_features["aatype"].device)
+
+        # Run the model
+        outputs = self(all_chain_features)
+
+        # Compute loss
+        loss = self.loss(
+            outputs, (all_chain_features, ground_truth), _return_breakdown=False
+        )
+
+        # Log it
+        self._log(loss, all_chain_features, outputs)
+
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        all_chain_features, ground_truth = batch
+        # At the start of validation, load the EMA weights
+        if(self.cached_weights is None):
+            # model.state_dict() contains references to model weights rather
+            # than copies. Therefore, we need to clone them before calling
+            # load_state_dict().
+            clone_param = lambda t: t.detach().clone()
+            self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
+            self.model.load_state_dict(self.ema.state_dict()["params"])
+
+        # Run the model
+        outputs = self(all_chain_features)
+
+        # Compute loss and other metrics
+        all_chain_features["use_clamped_fape"] = 0.
+        # Pass the same (features, ground_truth) tuple the loss receives in
+        # training_step
+        _, loss_breakdown = self.loss(
+            outputs, (all_chain_features, ground_truth), _return_breakdown=True
+        )
+
+        self._log(loss_breakdown, all_chain_features, outputs, train=False)
+
+    def validation_epoch_end(self, _):
+        # Restore the model weights to normal
+        self.model.load_state_dict(self.cached_weights)
+        self.cached_weights = None
 def main(args):
     if(args.seed is not None):
         seed_everything(args.seed)
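Note on the EMA swap in the class above: validation_step caches the training weights and loads the EMA weights, and validation_epoch_end restores them. Because model.state_dict() returns references to the live parameters rather than copies, the cache must be cloned first. A minimal standalone sketch of the pattern, using a toy model and a dict-only stand-in for OpenFold's tensor_tree_map:

import torch

def tensor_tree_map(fn, tree):
    # Dict-only stand-in; OpenFold's version walks arbitrary nested containers
    return {k: fn(v) for k, v in tree.items()}

model = torch.nn.Linear(4, 4)
# A fake set of EMA weights, stored the way ema.state_dict()["params"] would be
ema_params = {k: v.detach().clone() for k, v in model.state_dict().items()}

# Start of validation: clone the live weights, then swap in the EMA weights
cached_weights = tensor_tree_map(lambda t: t.detach().clone(), model.state_dict())
model.load_state_dict(ema_params)

# ... run validation batches here ...

# End of validation: restore the original training weights
model.load_state_dict(cached_weights)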
@@ -266,8 +329,11 @@ def main(args):
         train=True,
         low_prec=(str(args.precision) == "16")
     )
-    model_module = OpenFoldWrapper(config)
+    if "multimer" in args.config_preset:
+        print("training multimer models now")
+        model_module = OpenFoldMultimerWrapper(config)
+    else:
+        model_module = OpenFoldWrapper(config)
     if(args.resume_from_ckpt):
         if(os.path.isdir(args.resume_from_ckpt)):
             last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt)
@@ -293,11 +359,19 @@ def main(args):
         script_preset_(model_module)
     #data_module = DummyDataLoader("new_batch.pickle")
-    data_module = OpenFoldDataModule(
-        config=config.data,
-        batch_seed=args.seed,
-        **vars(args)
-    )
+    if "multimer" in args.config_preset:
+        print("use multimer datamodule now")
+        data_module = OpenFoldMultimerDataModule(
+            config=config.data,
+            batch_seed=args.seed,
+            **vars(args)
+        )
+    else:
+        data_module = OpenFoldDataModule(
+            config=config.data,
+            batch_seed=args.seed,
+            **vars(args)
+        )
     data_module.prepare_data()
     data_module.setup()
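For reference, the multimer branch assumes a different batch contract from the monomer one: OpenFoldMultimerDataModule is expected to yield (all_chain_features, ground_truth) tuples, which the wrapper's training_step and validation_step unpack, while the monomer datamodule yields a single feature dict. A hedged sketch with made-up keys and shapes (the real OpenFold feature set is much larger):

import torch

def fake_multimer_batch():
    # Illustrative only; key names and shapes are assumptions
    all_chain_features = {
        "aatype": torch.zeros(1, 128, dtype=torch.long),   # per-residue types
    }
    ground_truth = {
        "all_atom_positions": torch.zeros(1, 128, 37, 3),  # atom37 coordinates
    }
    return all_chain_features, ground_truth

# Mirrors the unpacking done in OpenFoldMultimerWrapper.training_step
all_chain_features, ground_truth = fake_multimer_batch()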
@@ -417,6 +491,10 @@ if __name__ == "__main__":
         help='''Cutoff for all templates. In training mode, templates are also
                 filtered by the release date of the target'''
     )
+    parser.add_argument(
+        "--train_mmcif_data_cache_path", type=str, default=None,
+        help="Path to a JSON file recording metadata for the mmCIF structures used during training"
+    )
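The diff does not show the schema of the cache referenced by --train_mmcif_data_cache_path. As a purely hypothetical illustration of the kind of per-structure metadata such a cache might record (field names are assumptions, not the actual format):

import json

# Hypothetical cache entry keyed by PDB ID; all fields are illustrative
example_cache = {
    "1abc": {
        "release_date": "2001-04-25",
        "resolution": 2.1,
        "chain_ids": ["A", "B"],
    },
}
print(json.dumps(example_cache, indent=2))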
     parser.add_argument(
         "--distillation_data_dir", type=str, default=None,
         help="Directory containing training PDB files"