Unverified Commit 0188361b authored by double-vin's avatar double-vin Committed by GitHub
Browse files

add seed_torch function (#147)


Co-authored-by: default avatarzhuww <zhuww@sugon.com>
parent 71735d5d
...@@ -47,6 +47,17 @@ from fastfold.utils.tensor_utils import tensor_tree_map ...@@ -47,6 +47,17 @@ from fastfold.utils.tensor_utils import tensor_tree_map
if int(torch.__version__.split(".")[0]) >= 1 and int(torch.__version__.split(".")[1]) > 11: if int(torch.__version__.split(".")[0]) >= 1 and int(torch.__version__.split(".")[1]) > 11:
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
def seed_torch(seed=1029):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)
@contextlib.contextmanager @contextlib.contextmanager
def temp_fasta_file(fasta_str: str): def temp_fasta_file(fasta_str: str):
with tempfile.NamedTemporaryFile('w', suffix='.fasta') as fasta_file: with tempfile.NamedTemporaryFile('w', suffix='.fasta') as fasta_file:
...@@ -215,9 +226,11 @@ def inference_multimer_model(args): ...@@ -215,9 +226,11 @@ def inference_multimer_model(args):
) )
output_dir_base = args.output_dir output_dir_base = args.output_dir
random_seed = args.data_random_seed random_seed = args.data_random_seed
if random_seed is None: if random_seed is None:
random_seed = random.randrange(sys.maxsize) random_seed = random.randrange(sys.maxsize)
# seed_torch(seed=1029)
feature_processor = feature_pipeline.FeaturePipeline( feature_processor = feature_pipeline.FeaturePipeline(
config.data config.data
...@@ -347,9 +360,12 @@ def inference_monomer_model(args): ...@@ -347,9 +360,12 @@ def inference_monomer_model(args):
data_processor = data_pipeline.DataPipeline(template_featurizer=template_featurizer,) data_processor = data_pipeline.DataPipeline(template_featurizer=template_featurizer,)
output_dir_base = args.output_dir output_dir_base = args.output_dir
random_seed = args.data_random_seed random_seed = args.data_random_seed
if random_seed is None: if random_seed is None:
random_seed = random.randrange(sys.maxsize) random_seed = random.randrange(sys.maxsize)
# seed_torch(seed=1029)
feature_processor = feature_pipeline.FeaturePipeline(config.data) feature_processor = feature_pipeline.FeaturePipeline(config.data)
if not os.path.exists(output_dir_base): if not os.path.exists(output_dir_base):
os.makedirs(output_dir_base) os.makedirs(output_dir_base)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment