Commit cfd2e719 authored by Jennifer's avatar Jennifer
Browse files

seed workers fix and validation_epoch_end extra argument

parent a317ad27
...@@ -9,6 +9,7 @@ from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor ...@@ -9,6 +9,7 @@ from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DeepSpeedStrategy, DDPStrategy from pytorch_lightning.strategies import DeepSpeedStrategy, DDPStrategy
from pytorch_lightning import seed_everything
import torch import torch
from openfold.config import model_config from openfold.config import model_config
...@@ -24,7 +25,6 @@ from openfold.utils.exponential_moving_average import ExponentialMovingAverage ...@@ -24,7 +25,6 @@ from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.loss import AlphaFoldLoss, lddt_ca from openfold.utils.loss import AlphaFoldLoss, lddt_ca
from openfold.utils.lr_schedulers import AlphaFoldLRScheduler from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
from openfold.utils.multi_chain_permutation import multi_chain_permutation_align from openfold.utils.multi_chain_permutation import multi_chain_permutation_align
from openfold.utils.seed import seed_everything
from openfold.utils.superimposition import superimpose from openfold.utils.superimposition import superimpose
from openfold.utils.tensor_utils import tensor_tree_map from openfold.utils.tensor_utils import tensor_tree_map
from openfold.utils.validation_metrics import ( from openfold.utils.validation_metrics import (
...@@ -155,7 +155,7 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -155,7 +155,7 @@ class OpenFoldWrapper(pl.LightningModule):
self._log(loss_breakdown, batch, outputs, train=False) self._log(loss_breakdown, batch, outputs, train=False)
def on_validation_epoch_end(self, _): def on_validation_epoch_end(self):
# Restore the model weights to normal # Restore the model weights to normal
self.model.load_state_dict(self.cached_weights) self.model.load_state_dict(self.cached_weights)
self.cached_weights = None self.cached_weights = None
...@@ -276,7 +276,7 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -276,7 +276,7 @@ class OpenFoldWrapper(pl.LightningModule):
def main(args): def main(args):
if (args.seed is not None): if (args.seed is not None):
seed_everything(args.seed) seed_everything(args.seed, workers=True)
config = model_config( config = model_config(
args.config_preset, args.config_preset,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment