"vscode:/vscode.git/clone" did not exist on "8c550d4c85ed15058ab7546ce1307ed875c60862"
Unverified Commit ab77af82 authored by Rui Xu's avatar Rui Xu Committed by GitHub
Browse files

add set_random_seed and use rank shift feature (#373)

parent e43fe0e2
# Copyright (c) Open-MMLab. All rights reserved.
import os
import random
import sys
import time
from getpass import getuser
from socket import gethostname
import numpy as np
import torch
import mmcv
......@@ -48,3 +53,29 @@ def obj_from_dict(info, parent=None, default_args=None):
for name, value in default_args.items():
args.setdefault(name, value)
return obj_type(**args)
def set_random_seed(seed, deterministic=False, use_rank_shift=False):
"""Set random seed.
Args:
seed (int): Seed to be used.
deterministic (bool): Whether to set the deterministic option for
CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
to True and `torch.backends.cudnn.benchmark` to False.
Default: False.
rank_shift (bool): Whether to add rank number to the random seed to
have different random seed in different threads. Default: False.
"""
if use_rank_shift:
rank, _ = mmcv.runner.get_dist_info()
seed += rank
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
if deterministic:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment