Unverified Commit 6b73a226 authored by Wenhao Wu's avatar Wenhao Wu Committed by GitHub
Browse files

[Enchance] Set a random seed when the user does not set a seed. (#1072)

parent 81764388
......@@ -4,10 +4,11 @@ from .inference import (convert_SyncBN, inference_detector,
inference_multi_modality_detector, inference_segmentor,
init_model, show_result_meshlab)
from .test import single_gpu_test
from .train import train_model
from .train import init_random_seed, train_model
__all__ = [
'inference_detector', 'init_model', 'single_gpu_test',
'inference_mono_3d_detector', 'show_result_meshlab', 'convert_SyncBN',
'train_model', 'inference_multi_modality_detector', 'inference_segmentor'
'train_model', 'inference_multi_modality_detector', 'inference_segmentor',
'init_random_seed'
]
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from mmcv.runner import get_dist_info
from torch import distributed as dist
from mmdet.apis import train_detector
from mmseg.apis import train_segmentor
def init_random_seed(seed=None, device='cuda'):
"""Initialize random seed.
If the seed is not set, the seed will be automatically randomized,
and then broadcast to all processes to prevent some potential bugs.
Args:
seed (int, optional): The seed. Default to None.
device (str, optional): The device where the seed will be put on.
Default to 'cuda'.
Returns:
int: Seed to be used.
"""
if seed is not None:
return seed
# Make sure all ranks share the same random seed to prevent
# some potential bugs. Please refer to
# https://github.com/open-mmlab/mmdetection/issues/6339
rank, world_size = get_dist_info()
seed = np.random.randint(2**31)
if world_size == 1:
return seed
if rank == 0:
random_num = torch.tensor(seed, dtype=torch.int32, device=device)
else:
random_num = torch.tensor(0, dtype=torch.int32, device=device)
dist.broadcast(random_num, src=0)
return random_num.item()
def train_model(model,
dataset,
cfg,
......
......@@ -14,7 +14,7 @@ from os import path as osp
from mmdet import __version__ as mmdet_version
from mmdet3d import __version__ as mmdet3d_version
from mmdet3d.apis import train_model
from mmdet3d.apis import init_random_seed, train_model
from mmdet3d.datasets import build_dataset
from mmdet3d.models import build_model
from mmdet3d.utils import collect_env, get_root_logger
......@@ -169,12 +169,12 @@ def main():
logger.info(f'Config:\n{cfg.pretty_text}')
# set random seeds
if args.seed is not None:
logger.info(f'Set random seed to {args.seed}, '
seed = init_random_seed(args.seed)
logger.info(f'Set random seed to {seed}, '
f'deterministic: {args.deterministic}')
set_random_seed(args.seed, deterministic=args.deterministic)
cfg.seed = args.seed
meta['seed'] = args.seed
set_random_seed(seed, deterministic=args.deterministic)
cfg.seed = seed
meta['seed'] = seed
meta['exp_name'] = osp.basename(args.config)
model = build_model(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment