Unverified Commit 2eb0a10d authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Feature] Add worker_init_fn (#1788)



* add worker_init_fn

* "Fix as comment"

* Fix format
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent e8cf9613
...@@ -50,6 +50,7 @@ else: ...@@ -50,6 +50,7 @@ else:
is_rocm_pytorch) is_rocm_pytorch)
# yapf: enable # yapf: enable
from .registry import Registry, build_from_cfg from .registry import Registry, build_from_cfg
from .seed import worker_init_fn
from .trace import is_jit_tracing from .trace import is_jit_tracing
__all__ = [ __all__ = [
'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger', 'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger',
...@@ -70,5 +71,5 @@ else: ...@@ -70,5 +71,5 @@ else:
'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer', 'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer',
'assert_params_all_zeros', 'check_python_script', 'assert_params_all_zeros', 'check_python_script',
'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch', 'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch',
'_get_cuda_home', 'load_url', 'has_method' '_get_cuda_home', 'load_url', 'has_method', 'worker_init_fn'
] ]
# Copyright (c) OpenMMLab. All rights reserved.
import random
import numpy as np
import torch
def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
"""Function to initialize each worker.
The seed of each worker equals to
``num_worker * rank + worker_id + user_seed``.
Args:
worker_id (int): Id for each worker.
num_workers (int): Number of workers.
rank (int): Rank in distributed training.
seed (int): Random seed.
"""
worker_seed = num_workers * rank + worker_id + seed
np.random.seed(worker_seed)
random.seed(worker_seed)
torch.manual_seed(worker_seed)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment