Commit 4c06a4fd authored by Guolin Ke
Browse files

Refine RNG state handling in trainer

parent 70dcdf17
......@@ -717,7 +717,7 @@ class Trainer(object):
with torch.autograd.profiler.record_function("optimizer"):
# fix the seed so that stochastic rounding is consistent across ranks
with utils.torch_seed(self.args.seed, self.get_num_updates(), -1):
with utils.torch_seed(self.args.seed, self.get_num_updates()):
# take an optimization step
self.task.optimizer_step(
self.optimizer,
......@@ -733,8 +733,14 @@ class Trainer(object):
# out where it fails
self.zero_grad()
with NanDetector(self.get_model()):
for _, sample in enumerate(samples):
for i, sample in enumerate(samples):
sample, _ = self._prepare_sample(sample)
with utils.torch_seed(
self.args.seed,
self.get_num_updates(),
i,
self.data_parallel_rank,
):
self.task.train_step(
sample,
self.model,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment