Commit a56ea9b5 authored by Lukas Jarosch's avatar Lukas Jarosch
Browse files

Fix distributed seeding behavior

This adds workers=True to the Lightning seed_everything call, which guarantees different random states across all processes in distributed training. Prior to this change, some processes on different GPUs with the same worker ID could share the same random state.

Note that this will break reproducibility between runs prior to and after this change.

Also removes the seed and suppress_output modules, which are no longer used in OpenFold.
parent ef0c9fac
import os
import logging
import random
import numpy as np
from pytorch_lightning.utilities.seed import seed_everything
from openfold.utils.suppress_output import SuppressLogging
def seed_globally(seed=None):
    """Publish a global seed through the environment, then seed via Lightning.

    If ``PL_GLOBAL_SEED`` is not yet present in ``os.environ``, it is set to
    ``seed``; when ``seed`` is ``None`` a random value in the ``uint32`` range
    is drawn first. An already-present ``PL_GLOBAL_SEED`` is left untouched,
    in which case the ``seed`` argument is not used.

    Args:
        seed: Optional seed to publish. Ignored when ``PL_GLOBAL_SEED`` is
            already set in the environment.
    """
    env_key = "PL_GLOBAL_SEED"
    if env_key not in os.environ:
        if seed is None:
            upper_bound = np.iinfo(np.uint32).max
            seed = random.randint(0, upper_bound)
        os.environ[env_key] = str(seed)
        logging.info(f'os.environ["PL_GLOBAL_SEED"] set to {seed}')

    # seed_everything is chatty at INFO level, so mute that while it runs.
    # NOTE(review): passing seed=None presumably makes Lightning read the
    # value from PL_GLOBAL_SEED -- confirm against the Lightning docs.
    with SuppressLogging(logging.INFO):
        seed_everything(seed=None)
import logging
import sys
class SuppressStdout:
    """Context manager that discards everything written to ``sys.stdout``.

    On entry ``sys.stdout`` is replaced by a handle to the platform's null
    device; on exit the stream that was active at entry is put back and the
    null-device handle is closed. Exceptions raised inside the ``with`` body
    are not suppressed.
    """

    def __enter__(self):
        # Imported locally because the surrounding module only imports
        # ``logging`` and ``sys``.
        import os

        self.stdout = sys.stdout
        # os.devnull resolves to the right path on every platform, unlike a
        # hard-coded "/dev/null".
        self._dev_null = open(os.devnull, "w")
        sys.stdout = self._dev_null

    def __exit__(self, typ, value, traceback):
        sys.stdout = self.stdout
        # Close the handle this manager opened, not whatever sys.stdout
        # happened to be on exit: code inside the block may have swapped in
        # a stream of its own, which must not be closed here.
        self._dev_null.close()
class SuppressLogging:
    """Context manager that temporarily disables logging up to a level.

    On entry ``logging.disable(level)`` is applied. On exit the disable
    threshold that was in force at entry is restored, so a threshold set by
    an enclosing scope (or an outer ``SuppressLogging``) survives instead of
    being reset to ``logging.NOTSET``.

    Args:
        level: Logging level passed to ``logging.disable`` while the
            context is active.
    """

    def __init__(self, level):
        self.level = level
        self._previous = logging.NOTSET

    def __enter__(self):
        # Remember the process-wide threshold so __exit__ can put it back.
        self._previous = logging.root.manager.disable
        logging.disable(self.level)

    def __exit__(self, typ, value, traceback):
        logging.disable(self._previous)
......@@ -8,6 +8,7 @@ from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
from pytorch_lightning.utilities.seed import seed_everything
import torch
from openfold.config import model_config
......@@ -23,7 +24,6 @@ from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.loss import AlphaFoldLoss, lddt_ca
from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
from openfold.utils.multi_chain_permutation import multi_chain_permutation_align
from openfold.utils.seed import seed_everything
from openfold.utils.superimposition import superimpose
from openfold.utils.tensor_utils import tensor_tree_map
from openfold.utils.validation_metrics import (
......@@ -272,7 +272,7 @@ class OpenFoldWrapper(pl.LightningModule):
def main(args):
if(args.seed is not None):
seed_everything(args.seed)
seed_everything(args.seed, workers=True)
config = model_config(
args.config_preset,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment