Commit 9660a43d authored by Lukas Jarosch, committed by jnwei

Fix distributed seeding behavior

This adds workers=True to the Lightning seed_everything function, which guarantees different random states across all processes in distributed training. Previously, dataloader worker processes with the same worker ID on different GPUs could share the same random state.

Note that this breaks reproducibility between runs made before and after this change.

Also removes the seed and suppress_output modules, which are no longer used anywhere in OpenFold.
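
For reference, a minimal sketch (not part of this commit, with a hypothetical entry-point name) of the seeding call the training script now uses:

import pytorch_lightning as pl

def run_training(seed: int):
    # Illustrative sketch only. With workers=True, seed_everything also seeds
    # DataLoader worker processes: Lightning installs a worker_init_fn that
    # derives each worker's seed from the base seed, the worker id, and the
    # process's global rank, so identically numbered workers on different
    # GPUs no longer end up with the same random state.
    pl.seed_everything(seed, workers=True)

    # ... build the LightningModule, DataModule, and Trainer as usual ...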
parent cc6deaa8
openfold/utils/seed.py (deleted)

import os
import logging
import random
import numpy as np
from pytorch_lightning.utilities.seed import seed_everything
from openfold.utils.suppress_output import SuppressLogging


def seed_globally(seed=None):
    if("PL_GLOBAL_SEED" not in os.environ):
        if(seed is None):
            seed = random.randint(0, np.iinfo(np.uint32).max)
        os.environ["PL_GLOBAL_SEED"] = str(seed)
        logging.info(f'os.environ["PL_GLOBAL_SEED"] set to {seed}')

    # seed_everything is a bit log-happy
    with SuppressLogging(logging.INFO):
        seed_everything(seed=None)
openfold/utils/suppress_output.py (deleted)

import logging
import sys


class SuppressStdout:
    def __enter__(self):
        self.stdout = sys.stdout
        dev_null = open("/dev/null", "w")
        sys.stdout = dev_null

    def __exit__(self, typ, value, traceback):
        fp = sys.stdout
        sys.stdout = self.stdout
        fp.close()


class SuppressLogging:
    def __init__(self, level):
        self.level = level

    def __enter__(self):
        logging.disable(self.level)

    def __exit__(self, typ, value, traceback):
        logging.disable(logging.NOTSET)
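
For context, this is roughly how the removed helpers were used (illustrative only; the import path reflects the module before its removal):

import logging

from openfold.utils.suppress_output import SuppressLogging, SuppressStdout

# SuppressLogging temporarily disables all log records at or below the given
# level for the duration of the block.
with SuppressLogging(logging.INFO):
    logging.info("this message is swallowed")

# SuppressStdout redirects sys.stdout to /dev/null inside the block.
with SuppressStdout():
    print("this output is swallowed")
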
@@ -8,6 +8,7 @@ from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
+from pytorch_lightning.utilities.seed import seed_everything
 import torch
 from openfold.config import model_config
@@ -23,7 +24,6 @@ from openfold.utils.exponential_moving_average import ExponentialMovingAverage
 from openfold.utils.loss import AlphaFoldLoss, lddt_ca
 from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
 from openfold.utils.multi_chain_permutation import multi_chain_permutation_align
-from openfold.utils.seed import seed_everything
 from openfold.utils.superimposition import superimpose
 from openfold.utils.tensor_utils import tensor_tree_map
 from openfold.utils.validation_metrics import (
@@ -273,7 +273,7 @@ class OpenFoldWrapper(pl.LightningModule):
 def main(args):
     if(args.seed is not None):
-        seed_everything(args.seed)
+        seed_everything(args.seed, workers=True)
     config = model_config(
         args.config_preset,
...