Commit cfd2e719 authored by Jennifer's avatar Jennifer
Browse files

seed workers fix and validation_epoch_end extra argument

parent a317ad27
......@@ -9,6 +9,7 @@ from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DeepSpeedStrategy, DDPStrategy
from pytorch_lightning import seed_everything
import torch
from openfold.config import model_config
......@@ -24,7 +25,6 @@ from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.loss import AlphaFoldLoss, lddt_ca
from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
from openfold.utils.multi_chain_permutation import multi_chain_permutation_align
from openfold.utils.seed import seed_everything
from openfold.utils.superimposition import superimpose
from openfold.utils.tensor_utils import tensor_tree_map
from openfold.utils.validation_metrics import (
......@@ -155,7 +155,7 @@ class OpenFoldWrapper(pl.LightningModule):
self._log(loss_breakdown, batch, outputs, train=False)
def on_validation_epoch_end(self, _):
def on_validation_epoch_end(self):
# Restore the model weights to normal
self.model.load_state_dict(self.cached_weights)
self.cached_weights = None
......@@ -276,7 +276,7 @@ class OpenFoldWrapper(pl.LightningModule):
def main(args):
if (args.seed is not None):
seed_everything(args.seed)
seed_everything(args.seed, workers=True)
config = model_config(
args.config_preset,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment