Unverified Commit 519b1564 authored by jihan.yang, committed by GitHub


Re-organize fix random seed code; Make saved checkpoint compatible among different torch versions (#986)
parent 846cf3ed
import torch
+from functools import partial
from torch.utils.data import DataLoader
from torch.utils.data import DistributedSampler as _DistributedSampler
@@ -44,7 +45,7 @@ class DistributedSampler(_DistributedSampler):
return iter(indices)
-def build_dataloader(dataset_cfg, class_names, batch_size, dist, root_path=None, workers=4,
+def build_dataloader(dataset_cfg, class_names, batch_size, dist, root_path=None, workers=4, seed=None,
logger=None, training=True, merge_all_iters_to_one_epoch=False, total_epochs=0):
dataset = __all__[dataset_cfg.DATASET](
@@ -70,7 +71,7 @@ def build_dataloader(dataset_cfg, class_names, batch_size, dist, root_path=None,
dataloader = DataLoader(
dataset, batch_size=batch_size, pin_memory=True, num_workers=workers,
shuffle=(sampler is None) and training, collate_fn=dataset.collate_batch,
-drop_last=False, sampler=sampler, timeout=0
+drop_last=False, sampler=sampler, timeout=0, worker_init_fn=partial(common_utils.worker_init_fn, seed=seed)
)
return dataset, dataloader, sampler
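The dataloader change threads the new `seed` argument into every worker process through `functools.partial`, so each worker derives its own reproducible RNG state. A minimal standalone sketch of the same pattern (the toy dataset and the seed value are illustrative, not repo code):

```python
import random
from functools import partial

import numpy as np
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Placeholder dataset whose items depend on numpy's global RNG."""
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return np.random.rand()


def worker_init_fn(worker_id, seed=None):
    # Give each worker process its own deterministic seed.
    if seed is not None:
        random.seed(seed + worker_id)
        np.random.seed(seed + worker_id)


if __name__ == '__main__':
    loader = DataLoader(
        ToyDataset(), batch_size=2, num_workers=2,
        worker_init_fn=partial(worker_init_fn, seed=666),  # 666 mirrors the PR's default seed
    )
    for batch in loader:
        print(batch)  # reproducible across runs, distinct per worker
```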
@@ -103,10 +103,20 @@ def set_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
+torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
+def worker_init_fn(worker_id, seed=666):
+    if seed is not None:
+        random.seed(seed + worker_id)
+        np.random.seed(seed + worker_id)
+        torch.manual_seed(seed + worker_id)
+        torch.cuda.manual_seed(seed + worker_id)
+        torch.cuda.manual_seed_all(seed + worker_id)
def get_pad_params(desired_size, cur_size):
"""
Get padding parameters for np.pad function
......
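The expanded `set_random_seed` covers the usual determinism checklist: Python, NumPy and Torch RNGs (CPU and CUDA) plus fixed cuDNN kernel selection. A small standalone sanity check of the idea (not repo code):

```python
import random

import numpy as np
import torch


def set_random_seed(seed):
    # Same recipe as the diff: seed every RNG the training loop may touch
    # and disable cuDNN autotuning so the kernel choice stays fixed.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_random_seed(666)
a = torch.rand(3)
set_random_seed(666)
b = torch.rand(3)
assert torch.equal(a, b)  # re-seeding reproduces the same draws
```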
@@ -76,7 +76,7 @@ def main():
args.epochs = cfg.OPTIMIZATION.NUM_EPOCHS if args.epochs is None else args.epochs
if args.fix_random_seed:
-    common_utils.set_random_seed(666)
+    common_utils.set_random_seed(666 + cfg.LOCAL_RANK)
output_dir = cfg.ROOT_DIR / 'output' / cfg.EXP_GROUP_PATH / cfg.TAG / args.extra_tag
ckpt_dir = output_dir / 'ckpt'
@@ -110,7 +110,8 @@ def main():
logger=logger,
training=True,
merge_all_iters_to_one_epoch=args.merge_all_iters_to_one_epoch,
-total_epochs=args.epochs
+total_epochs=args.epochs,
+seed=666 if args.fix_random_seed else None
)
model = build_network(model_cfg=cfg.MODEL, num_class=len(cfg.CLASS_NAMES), dataset=train_set)
......
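In `train.py` the fixed seed is now offset by the process rank, so each distributed worker follows a distinct but individually reproducible random stream (otherwise every rank would apply identical data augmentations). A hedged sketch of the pattern; reading `LOCAL_RANK` from the environment is an assumption here, the repo uses `cfg.LOCAL_RANK`:

```python
import os

import numpy as np
import torch


def set_random_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)


# torch.distributed launchers normally export LOCAL_RANK; fall back to 0
# for single-GPU runs. cfg.LOCAL_RANK in the diff plays the same role.
local_rank = int(os.environ.get('LOCAL_RANK', 0))
set_random_seed(666 + local_rank)  # 666 is the base seed used by the PR
```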
@@ -166,7 +166,13 @@ def save_checkpoint(state, filename='checkpoint'):
optimizer_state = state['optimizer_state']
state.pop('optimizer_state', None)
optimizer_filename = '{}_optim.pth'.format(filename)
+if torch.__version__ >= '1.4':
+    torch.save({'optimizer_state': optimizer_state}, optimizer_filename, _use_new_zipfile_serialization=False)
+else:
+    torch.save({'optimizer_state': optimizer_state}, optimizer_filename)
filename = '{}.pth'.format(filename)
+if torch.__version__ >= '1.4':
+    torch.save(state, filename, _use_new_zipfile_serialization=False)
+else:
+    torch.save(state, filename)
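`save_checkpoint` now forces the legacy (non-zipfile) serialization on newer PyTorch, so checkpoints written here stay loadable by versions that predate the zipfile format. One caveat the diff does not address: `torch.__version__ >= '1.4'` is a plain string comparison, which misorders releases such as '1.10'. A hedged sketch of the same compatibility idea with a parsed version check; the helper name is illustrative and `packaging` is an assumed dependency:

```python
import torch
from packaging import version  # assumed available; commonly installed alongside pip/setuptools


def save_compat(obj, filename):
    """Save obj in the legacy format whenever the running torch supports the flag."""
    if version.parse(torch.__version__) >= version.parse('1.4'):
        # Legacy serialization keeps the file readable by pre-1.6 PyTorch.
        torch.save(obj, filename, _use_new_zipfile_serialization=False)
    else:
        torch.save(obj, filename)


save_compat({'optimizer_state': {}}, 'checkpoint_optim.pth')
```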