Commit 300f7157 authored by Kai Chen's avatar Kai Chen
Browse files

allow manually setting random seeds

parent 143a8372
......@@ -4,6 +4,7 @@ import argparse
import logging
from collections import OrderedDict
import numpy as np
import torch
from mmcv import Config
from mmcv.torchpack import Runner, obj_from_dict
......@@ -53,6 +54,12 @@ def get_logger(log_level):
return logger
def set_random_seed(seed):
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def parse_args():
parser = argparse.ArgumentParser(description='Train a detector')
parser.add_argument('config', help='train config file path')
......@@ -63,6 +70,7 @@ def parse_args():
help='whether to add a validate phase')
parser.add_argument(
'--gpus', type=int, default=1, help='number of gpus to use')
parser.add_argument('--seed', type=int, help='random seed')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
......@@ -84,6 +92,11 @@ def main():
logger = get_logger(cfg.log_level)
# set random seed if specified
if args.seed is not None:
logger.info('Set random seed to {}'.format(args.seed))
set_random_seed(args.seed)
# init distributed environment if necessary
if args.launcher == 'none':
dist = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment