Commit 1715e3d5 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Improve global seeding, clean up training script a little

parent 971c41d2
import os
import logging
import random
import numpy as np
from pytorch_lightning.utilities.seed import seed_everything
from openfold.utils.suppress_output import SuppressLogging
def seed_globally(seed=None):
    """Seed all RNGs process-wide, exactly once per training run.

    The first call stores a seed in the ``PL_GLOBAL_SEED`` environment
    variable — either the caller-supplied one or a freshly drawn random
    uint32 — so that spawned worker processes inherit the same value.
    ``seed_everything(seed=None)`` then picks the seed up from that
    variable (Lightning reads ``PL_GLOBAL_SEED`` when no explicit seed
    is given). If the variable is already set, any ``seed`` argument is
    ignored and the existing value wins.
    """
    if "PL_GLOBAL_SEED" not in os.environ:
        chosen = seed if seed is not None else random.randint(0, np.iinfo(np.uint32).max)
        os.environ["PL_GLOBAL_SEED"] = str(chosen)
        print(f'os.environ["PL_GLOBAL_SEED"] set to {chosen}')

    # seed_everything is a bit log-happy
    with SuppressLogging(logging.INFO):
        seed_everything(seed=None)
import logging
import os
import sys
class SuppressStdout:
    """Context manager that discards everything written to stdout.

    While the context is active, ``sys.stdout`` is redirected to the OS
    null device; on exit the original stream is restored and the
    null-device handle is closed.
    """

    def __enter__(self):
        self.stdout = sys.stdout
        # os.devnull is portable ("/dev/null" on POSIX, "nul" on Windows),
        # unlike the previous hard-coded "/dev/null" path.
        sys.stdout = open(os.devnull, "w")
        # Return self so `with SuppressStdout() as s:` binds the manager
        # instead of None.
        return self

    def __exit__(self, typ, value, traceback):
        fp = sys.stdout
        sys.stdout = self.stdout
        fp.close()
        # Returning False propagates any exception raised in the body.
        return False
class SuppressLogging:
    """Temporarily mute logging at and below a given severity.

    Thin wrapper around ``logging.disable()``: entering the context
    suppresses every record whose level is <= ``level``; leaving it
    lifts the suppression entirely (back to ``logging.NOTSET``).
    """

    def __init__(self, level):
        # Severity threshold handed to logging.disable() on entry.
        self.level = level

    def __enter__(self):
        logging.disable(self.level)

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Unconditionally re-enable logging, even if the body raised.
        logging.disable(logging.NOTSET)
......@@ -2,7 +2,7 @@ import argparse
import logging
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "6"
#os.environ["CUDA_VISIBLE_DEVICES"] = "6"
#os.environ["MASTER_ADDR"]="10.119.81.14"
#os.environ["MASTER_PORT"]="42069"
#os.environ["NODE_RANK"]="0"
......@@ -25,10 +25,9 @@ from openfold.data.data_modules import (
from openfold.model.model import AlphaFold
from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.loss import AlphaFoldLoss
from openfold.utils.seed import seed_everything
from openfold.utils.tensor_utils import tensor_tree_map
import copy
class OpenFoldWrapper(pl.LightningModule):
def __init__(self, config):
......@@ -91,6 +90,9 @@ class OpenFoldWrapper(pl.LightningModule):
def main(args):
if(args.seed is not None):
seed_everything(args.seed)
config = model_config(
"model_1",
train=True,
......@@ -111,9 +113,6 @@ def main(args):
if(args.deepspeed_config_path is not None):
plugins.append(DeepSpeedPlugin(config=args.deepspeed_config_path))
#os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
#plugins.append(DDPPlugin(find_unused_parameters=True))
trainer = pl.Trainer.from_argparse_args(
args,
plugins=plugins,
......@@ -196,10 +195,4 @@ if __name__ == "__main__":
args = parser.parse_args()
if(args.seed is not None):
torch.manual_seed(args.seed)
random.seed(args.seed + 1)
np.random.seed(args.seed + 2)
args.seed += 1
main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment