Unverified Commit 1ebd7ea6 authored by Jintao Lin's avatar Jintao Lin Committed by GitHub
Browse files

add unittest for set_random_seed (#376)

parent 0970ae94
......@@ -14,7 +14,7 @@ from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS,
DefaultOptimizerConstructor, build_optimizer,
build_optimizer_constructor)
from .priority import Priority, get_priority
from .utils import get_host_info, get_time_str, obj_from_dict
from .utils import get_host_info, get_time_str, obj_from_dict, set_random_seed
__all__ = [
'BaseRunner', 'Runner', 'EpochBasedRunner', 'IterBasedRunner', 'LogBuffer',
......@@ -27,5 +27,5 @@ __all__ = [
'obj_from_dict', 'init_dist', 'get_dist_info', 'master_only',
'OPTIMIZER_BUILDERS', 'OPTIMIZERS', 'DefaultOptimizerConstructor',
'build_optimizer', 'build_optimizer_constructor', 'IterLoader',
'IterBasedRunner'
'IterBasedRunner', 'set_random_seed'
]
import os
import random
import numpy as np
import torch
from mmcv.runner import set_random_seed
def test_set_random_seed():
set_random_seed(0)
a_random = random.randint(0, 10)
a_np_random = np.random.rand(2, 2)
a_torch_random = torch.rand(2, 2)
assert torch.backends.cudnn.deterministic is False
assert torch.backends.cudnn.benchmark is False
assert os.environ['PYTHONHASHSEED'] == str(0)
set_random_seed(0, True)
b_random = random.randint(0, 10)
b_np_random = np.random.rand(2, 2)
b_torch_random = torch.rand(2, 2)
assert torch.backends.cudnn.deterministic is True
assert torch.backends.cudnn.benchmark is False
assert a_random == b_random
assert np.equal(a_np_random, b_np_random).all()
assert torch.equal(a_torch_random, b_torch_random)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment