Unverified Commit ab77af82 authored by Rui Xu's avatar Rui Xu Committed by GitHub
Browse files

add set_random_seed and use rank shift feature (#373)

parent e43fe0e2
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
import os
import random
import sys import sys
import time import time
from getpass import getuser from getpass import getuser
from socket import gethostname from socket import gethostname
import numpy as np
import torch
import mmcv import mmcv
...@@ -48,3 +53,29 @@ def obj_from_dict(info, parent=None, default_args=None): ...@@ -48,3 +53,29 @@ def obj_from_dict(info, parent=None, default_args=None):
for name, value in default_args.items(): for name, value in default_args.items():
args.setdefault(name, value) args.setdefault(name, value)
return obj_type(**args) return obj_type(**args)
def set_random_seed(seed, deterministic=False, use_rank_shift=False):
"""Set random seed.
Args:
seed (int): Seed to be used.
deterministic (bool): Whether to set the deterministic option for
CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
to True and `torch.backends.cudnn.benchmark` to False.
Default: False.
rank_shift (bool): Whether to add rank number to the random seed to
have different random seed in different threads. Default: False.
"""
if use_rank_shift:
rank, _ = mmcv.runner.get_dist_info()
seed += rank
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
if deterministic:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment