Unverified Commit 127f1e70 authored by Jennifer Wei's avatar Jennifer Wei Committed by GitHub
Browse files

Merge pull request #418 from aqlaboratory/seeding-fix

Fix distributed seeding behavior
parents ef0c9fac a56ea9b5
import os
import logging
import random
import numpy as np
from pytorch_lightning.utilities.seed import seed_everything
from openfold.utils.suppress_output import SuppressLogging
def seed_globally(seed=None):
if("PL_GLOBAL_SEED" not in os.environ):
if(seed is None):
seed = random.randint(0, np.iinfo(np.uint32).max)
os.environ["PL_GLOBAL_SEED"] = str(seed)
logging.info(f'os.environ["PL_GLOBAL_SEED"] set to {seed}')
# seed_everything is a bit log-happy
with SuppressLogging(logging.INFO):
seed_everything(seed=None)
import logging
import sys
class SuppressStdout:
def __enter__(self):
self.stdout = sys.stdout
dev_null = open("/dev/null", "w")
sys.stdout = dev_null
def __exit__(self, typ, value, traceback):
fp = sys.stdout
sys.stdout = self.stdout
fp.close()
class SuppressLogging:
def __init__(self, level):
self.level = level
def __enter__(self):
logging.disable(self.level)
def __exit__(self, typ, value, traceback):
logging.disable(logging.NOTSET)
...@@ -8,6 +8,7 @@ from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor ...@@ -8,6 +8,7 @@ from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
from pytorch_lightning.utilities.seed import seed_everything
import torch import torch
from openfold.config import model_config from openfold.config import model_config
...@@ -23,7 +24,6 @@ from openfold.utils.exponential_moving_average import ExponentialMovingAverage ...@@ -23,7 +24,6 @@ from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.loss import AlphaFoldLoss, lddt_ca from openfold.utils.loss import AlphaFoldLoss, lddt_ca
from openfold.utils.lr_schedulers import AlphaFoldLRScheduler from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
from openfold.utils.multi_chain_permutation import multi_chain_permutation_align from openfold.utils.multi_chain_permutation import multi_chain_permutation_align
from openfold.utils.seed import seed_everything
from openfold.utils.superimposition import superimpose from openfold.utils.superimposition import superimpose
from openfold.utils.tensor_utils import tensor_tree_map from openfold.utils.tensor_utils import tensor_tree_map
from openfold.utils.validation_metrics import ( from openfold.utils.validation_metrics import (
...@@ -272,7 +272,7 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -272,7 +272,7 @@ class OpenFoldWrapper(pl.LightningModule):
def main(args): def main(args):
if(args.seed is not None): if(args.seed is not None):
seed_everything(args.seed) seed_everything(args.seed, workers=True)
config = model_config( config = model_config(
args.config_preset, args.config_preset,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment