Commit 104cead1 authored by Myle Ott's avatar Myle Ott
Browse files

Set seed after each epoch to improve consistency when resuming

parent 8b4c45a2
...@@ -57,9 +57,6 @@ class MultiprocessingTrainer(MultiprocessingEventLoop): ...@@ -57,9 +57,6 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
"""Initialize child processes.""" """Initialize child processes."""
self.args = args self.args = args
# set torch.seed in this process
torch.manual_seed(args.seed)
# set CUDA device # set CUDA device
torch.cuda.set_device(device_id) torch.cuda.set_device(device_id)
...@@ -142,6 +139,15 @@ class MultiprocessingTrainer(MultiprocessingEventLoop): ...@@ -142,6 +139,15 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
self.lr_scheduler, cuda_device=device_id) self.lr_scheduler, cuda_device=device_id)
return extra_state return extra_state
def set_seed(self, seed):
    """Seed the torch RNG on every replica.

    Dispatches an async `_async_set_seed` call to each worker process and
    blocks until all of them have completed, so every replica's RNG is in a
    known state before training continues.
    """
    pending = []
    for replica in range(self.num_replicas):
        pending.append(self.call_async(replica, '_async_set_seed', seed=seed))
    Future.gen_list(pending)
def _async_set_seed(self, rank, device_id, seed):
torch.manual_seed(seed)
def train_step(self, samples): def train_step(self, samples):
"""Do forward, backward and gradient step in parallel.""" """Do forward, backward and gradient step in parallel."""
# PyTorch initializes gradient buffers lazily, so the first # PyTorch initializes gradient buffers lazily, so the first
......
...@@ -133,6 +133,7 @@ def train(args, epoch, batch_offset, trainer, dataset, num_gpus): ...@@ -133,6 +133,7 @@ def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
extra_meters = collections.defaultdict(lambda: AverageMeter()) extra_meters = collections.defaultdict(lambda: AverageMeter())
desc = '| epoch {:03d}'.format(epoch) desc = '| epoch {:03d}'.format(epoch)
trainer.set_seed(args.seed + epoch)
lr = trainer.get_lr() lr = trainer.get_lr()
with progress_bar(itr, desc, leave=False) as t: with progress_bar(itr, desc, leave=False) as t:
for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset): for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment